OpenCLKernels.cpp 478 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
9
 * Portions copyright (c) 2008-2019 Stanford University and the Authors.      *
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLKernels.h"
28
#include "OpenCLForceInfo.h"
29
30
#include "openmm/LangevinIntegrator.h"
#include "openmm/Context.h"
31
#include "openmm/internal/AndersenThermostatImpl.h"
32
#include "openmm/internal/CMAPTorsionForceImpl.h"
33
#include "openmm/internal/ContextImpl.h"
34
#include "openmm/internal/CustomCentroidBondForceImpl.h"
35
#include "openmm/internal/CustomCompoundBondForceImpl.h"
36
#include "openmm/internal/CustomHbondForceImpl.h"
37
#include "openmm/internal/CustomManyParticleForceImpl.h"
38
#include "openmm/internal/CustomNonbondedForceImpl.h"
39
#include "openmm/internal/NonbondedForceImpl.h"
40
#include "openmm/internal/OSRngSeed.h"
Peter Eastman's avatar
Peter Eastman committed
41
#include "OpenCLBondedUtilities.h"
42
#include "OpenCLExpressionUtilities.h"
43
#include "OpenCLIntegrationUtilities.h"
44
#include "OpenCLNonbondedUtilities.h"
45
#include "OpenCLKernelSources.h"
46
#include "lepton/CustomFunction.h"
47
#include "lepton/ExpressionTreeNode.h"
48
#include "lepton/Operation.h"
49
50
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
51
#include "ReferenceTabulatedFunction.h"
52
53
#include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMUtilities.h"
peastman's avatar
peastman committed
54
#include "jama_eig.h"
55
#include <algorithm>
56
#include <assert.h>
57
#include <cmath>
58
#include <iterator>
59
#include <set>
60
61
62

using namespace OpenMM;
using namespace std;
63
using namespace Lepton;
64

65
66
67
68
69
70
71
/**
 * Bind the position-correction argument of a kernel.
 *
 * When the context uses mixed precision, the device buffer of cl.getPosqCorrection()
 * is bound to the argument; otherwise a NULL pointer is passed in its place.
 *
 * @param cl      the context the kernel belongs to
 * @param kernel  the kernel whose argument is being set
 * @param index   the index of the argument to set
 */
static void setPosqCorrectionArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseMixedPrecision()) {
        kernel.setArg<void*>(index, NULL);
        return;
    }
    kernel.setArg<cl::Buffer>(index, cl.getPosqCorrection().getDeviceBuffer());
}

72
73
74
75
76
77
78
/**
 * Bind the periodic box size to a kernel argument, using the double precision
 * value when the context runs in double precision and the single precision
 * value otherwise.
 *
 * @param cl      the context the kernel belongs to
 * @param kernel  the kernel whose argument is being set
 * @param index   the index of the argument to set
 */
static void setPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
}

79
/**
 * Bind the five periodic box arguments of a kernel, starting at the given index.
 *
 * The arguments are set in this order: box size, inverse box size, then the
 * X, Y and Z box vectors.  Double precision values are used when the context
 * runs in double precision, single precision values otherwise.
 *
 * @param cl      the context the kernel belongs to
 * @param kernel  the kernel whose arguments are being set
 * @param index   the index of the first of the five consecutive arguments
 */
static void setPeriodicBoxArgs(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+1, cl.getInvPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+2, cl.getPeriodicBoxVecX());
        kernel.setArg<mm_float4>(index+3, cl.getPeriodicBoxVecY());
        kernel.setArg<mm_float4>(index+4, cl.getPeriodicBoxVecZ());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
    kernel.setArg<mm_double4>(index+1, cl.getInvPeriodicBoxSizeDouble());
    kernel.setArg<mm_double4>(index+2, cl.getPeriodicBoxVecXDouble());
    kernel.setArg<mm_double4>(index+3, cl.getPeriodicBoxVecYDouble());
    kernel.setArg<mm_double4>(index+4, cl.getPeriodicBoxVecZDouble());
}

96
97
98
99
100
101
102
/**
 * Determine whether an expression consists of nothing but the constant 0.
 *
 * @param expression  the expression to inspect
 * @return true if the root operation is a CONSTANT whose value equals 0.0
 */
static bool isZeroExpression(const Lepton::ParsedExpression& expression) {
    const Lepton::Operation& root = expression.getRootNode().getOperation();
    if (root.getId() == Lepton::Operation::CONSTANT) {
        const Lepton::Operation::Constant& constant = dynamic_cast<const Lepton::Operation::Constant&>(root);
        return (constant.getValue() == 0.0);
    }
    return false;
}

103
104
105
106
/**
 * Determine whether a named variable appears anywhere in an expression subtree.
 *
 * @param node      the root of the subtree to search
 * @param variable  the name of the variable to look for
 * @return true if this node, or any node below it, is a VARIABLE operation with that name
 */
static bool usesVariable(const Lepton::ExpressionTreeNode& node, const string& variable) {
    const Lepton::Operation& operation = node.getOperation();
    const bool isRequestedVariable = (operation.getId() == Lepton::Operation::VARIABLE && operation.getName() == variable);
    if (isRequestedVariable)
        return true;

    // Not this node, so search each child subtree in turn.

    const auto& children = node.getChildren();
    for (auto child = children.begin(); child != children.end(); ++child) {
        if (usesVariable(*child, variable))
            return true;
    }
    return false;
}

/**
 * Determine whether a named variable appears anywhere in a parsed expression.
 * The search starts from the expression's root node.
 *
 * @param expression  the expression to search
 * @param variable    the name of the variable to look for
 */
static bool usesVariable(const Lepton::ParsedExpression& expression, const string& variable) {
    const Lepton::ExpressionTreeNode& root = expression.getRootNode();
    return usesVariable(root, variable);
}

117
118
119
120
/**
 * Build a pair that associates a variable node with a string.
 *
 * @param name   the name given to the newly created Variable operation
 * @param value  the string stored as the second element of the pair
 * @return the (variable node, value) pair
 */
static pair<ExpressionTreeNode, string> makeVariable(const string& name, const string& value) {
    Operation* variableOp = new Operation::Variable(name);
    return make_pair(ExpressionTreeNode(variableOp), value);
}

121
122
123
124
125
126
127
128
129
/**
 * Replace every CUSTOM operation in an expression program with a new one that
 * wraps a clone of the corresponding function from the supplied map.
 *
 * The lookup uses find() rather than operator[], so an unknown function name
 * no longer inserts a null entry into the map and dereferences it.
 *
 * @param functions   maps function names to the functions to clone
 * @param expression  the program whose CUSTOM operations are replaced in place
 * @throws OpenMMException if a CUSTOM operation names a function that is missing
 *         from the map or is mapped to a null pointer
 */
static void replaceFunctionsInExpression(map<string, CustomFunction*>& functions, ExpressionProgram& expression) {
    for (int i = 0; i < expression.getNumOperations(); i++) {
        if (expression.getOperation(i).getId() != Operation::CUSTOM)
            continue;
        const Operation::Custom& op = dynamic_cast<const Operation::Custom&>(expression.getOperation(i));
        map<string, CustomFunction*>::const_iterator entry = functions.find(op.getName());
        if (entry == functions.end() || entry->second == NULL)
            throw OpenMMException("replaceFunctionsInExpression: Unknown function '"+op.getName()+"'");
        expression.setOperation(i, new Operation::Custom(op.getName(), entry->second->clone(), op.getDerivOrder()));
    }
}

130
/**
 * Initialize the kernel.  This implementation has nothing to set up.
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) {
    // Intentionally empty.
}

133
/**
 * Prepare the context for a force/energy evaluation.
 *
 * Marks the forces as valid, clears the autoclear buffers, runs every registered
 * pre-computation, increments the force computation counter, asks the nonbonded
 * utilities to prepare their interactions, and resets the workspace entry of
 * every context parameter to zero.
 *
 * @param context        the context in which the computation is being done
 * @param includeForces  passed through to each pre-computation
 * @param includeEnergy  passed through to each pre-computation
 * @param groups         the force group flags passed to the pre-computations and
 *                       the nonbonded utilities
 */
void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
    cl.setForcesValid(true);
    cl.clearAutoclearBuffers();
    for (auto preComputation : cl.getPreComputations())
        preComputation->computeForceAndEnergy(includeForces, includeEnergy, groups);
    OpenCLNonbondedUtilities& nonbonded = cl.getNonbondedUtilities();
    const int previousCount = cl.getComputeForceCount();
    cl.setComputeForceCount(previousCount+1);
    nonbonded.prepareInteractions(groups);

    // Start every parameter derivative from zero for this evaluation.

    map<string, double>& derivWorkspace = cl.getEnergyParamDerivWorkspace();
    for (auto& parameter : context.getParameters())
        derivWorkspace[parameter.first] = 0;
}

146
/**
 * Complete a force/energy evaluation.
 *
 * Runs the bonded and nonbonded interactions, then every registered
 * post-computation, reduces the forces and distributes them from virtual sites.
 *
 * @param context        the context in which the computation is being done
 * @param includeForces  passed through to the nonbonded utilities and post-computations
 * @param includeEnergy  if true, the reduced energy is added to the returned sum
 * @param groups         the force group flags to evaluate
 * @param valid          set to false if the context reports that the forces are
 *                       not valid; it is never set to true here
 * @return the sum of the values returned by the post-computations, plus the
 *         reduced energy when includeEnergy is true
 */
double OpenCLCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) {
    cl.getBondedUtilities().computeInteractions(groups);
    cl.getNonbondedUtilities().computeInteractions(groups, includeForces, includeEnergy);
    double energy = 0.0;
    for (auto postComputation : cl.getPostComputations())
        energy += postComputation->computeForceAndEnergy(includeForces, includeEnergy, groups);
    cl.reduceForces();
    cl.getIntegrationUtilities().distributeForcesFromVirtualSites();
    if (includeEnergy)
        energy += cl.reduceEnergy();
    const bool forcesAreValid = cl.getForcesValid();
    if (!forcesAreValid)
        valid = false;
    return energy;
}

161
/**
 * Initialize the kernel.  This implementation has nothing to set up.
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLUpdateStateDataKernel::initialize(const System& system) {
    // Intentionally empty.
}

164
double OpenCLUpdateStateDataKernel::getTime(const ContextImpl& context) const {
165
    return cl.getTime();
166
167
}

168
/**
 * Set the current time on every OpenCL context listed in the platform data.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param time     the time to set
 */
void OpenCLUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
    vector<OpenCLContext*>& allContexts = cl.getPlatformData().contexts;
    for (auto current = allContexts.begin(); current != allContexts.end(); ++current)
        (*current)->setTime(time);
}

peastman's avatar
peastman committed
174
175
176
177
178
179
180
181
182
183
184
185
186
/**
 * Retrieve the particle positions from the device.
 *
 * The posq array is downloaded into the pinned buffer (together with the
 * correction array in mixed precision).  Each position is then written to
 * positions[order[i]], where order is cl.getAtomIndex(), after subtracting
 * the periodic box vectors scaled by that atom's cell offset.
 *
 * @param context    the context in which to execute this kernel
 * @param positions  on exit, resized to the number of particles in the System
 *                   and filled with the positions
 */
void OpenCLUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
    const int systemParticles = context.getSystem().getNumParticles();
    positions.resize(systemParticles);
    const bool useDouble = cl.getUseDoublePrecision();
    const bool useMixed = (!useDouble && cl.getUseMixedPrecision());
    vector<mm_float4> posCorrection;
    if (useDouble) {
        mm_double4* pinned = (mm_double4*) cl.getPinnedBuffer();
        cl.getPosq().download(pinned);
    }
    else if (useMixed) {
        mm_float4* pinned = (mm_float4*) cl.getPinnedBuffer();
        cl.getPosq().download(pinned, false);
        posCorrection.resize(systemParticles);
        cl.getPosqCorrection().download(posCorrection);
    }
    else {
        mm_float4* pinned = (mm_float4*) cl.getPinnedBuffer();
        cl.getPosq().download(pinned);
    }

    // Filling in the output array is done in parallel for speed.  Each thread
    // handles one contiguous slice of the atoms.

    cl.getPlatformData().threads.execute([&] (ThreadPool& threads, int threadIndex) {
        const vector<int>& order = cl.getAtomIndex();
        const int atomCount = cl.getNumAtoms();
        Vec3 box[3];
        cl.getPeriodicBoxVectors(box[0], box[1], box[2]);
        const int threadCount = threads.getNumThreads();
        const int first = threadIndex*atomCount/threadCount;
        const int last = (threadIndex+1)*atomCount/threadCount;
        if (useDouble) {
            const mm_double4* posq = (mm_double4*) cl.getPinnedBuffer();
            for (int atom = first; atom < last; ++atom) {
                const mm_double4 p = posq[atom];
                const mm_int4 cell = cl.getPosCellOffsets()[atom];
                positions[order[atom]] = Vec3(p.x, p.y, p.z)-box[0]*cell.x-box[1]*cell.y-box[2]*cell.z;
            }
        }
        else if (useMixed) {
            const mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
            for (int atom = first; atom < last; ++atom) {
                // The full position is the sum of the single precision value and its correction.
                const mm_float4 base = posq[atom];
                const mm_float4 extra = posCorrection[atom];
                const mm_int4 cell = cl.getPosCellOffsets()[atom];
                positions[order[atom]] = Vec3((double)base.x+(double)extra.x, (double)base.y+(double)extra.y, (double)base.z+(double)extra.z)-box[0]*cell.x-box[1]*cell.y-box[2]*cell.z;
            }
        }
        else {
            const mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
            for (int atom = first; atom < last; ++atom) {
                const mm_float4 p = posq[atom];
                const mm_int4 cell = cl.getPosCellOffsets()[atom];
                positions[order[atom]] = Vec3(p.x, p.y, p.z)-box[0]*cell.x-box[1]*cell.y-box[2]*cell.z;
            }
        }
    });
    cl.getPlatformData().threads.waitForThreads();
}

Peter Eastman's avatar
Peter Eastman committed
234
/**
 * Upload new particle positions to the device.
 *
 * The existing posq array is downloaded first and only its x, y and z components
 * are overwritten, so the w component of each real atom keeps its downloaded value.
 * Padding atoms are set to zero.  In mixed precision the part of each coordinate
 * lost when converting to single precision is uploaded as the correction array.
 * Finally all cell offsets are reset to zero and cl.reorderAtoms() is called.
 *
 * @param context    the context in which to execute this kernel
 * @param positions  the new positions; positions[order[i]] is stored at index i,
 *                   where order is cl.getAtomIndex()
 */
void OpenCLUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) {
    const vector<cl_int>& order = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    if (!cl.getUseDoublePrecision()) {
        mm_float4* posq = (mm_float4*) cl.getPinnedBuffer();
        cl.getPosq().download(posq);
        for (int atom = 0; atom < numParticles; ++atom) {
            const Vec3& source = positions[order[atom]];
            mm_float4& target = posq[atom];
            target.x = (cl_float) source[0];
            target.y = (cl_float) source[1];
            target.z = (cl_float) source[2];
        }
        for (int atom = numParticles; atom < cl.getPaddedNumAtoms(); atom++)
            posq[atom] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
        cl.getPosq().upload(posq);
    }
    else {
        mm_double4* posq = (mm_double4*) cl.getPinnedBuffer();
        cl.getPosq().download(posq);
        for (int atom = 0; atom < numParticles; ++atom) {
            const Vec3& source = positions[order[atom]];
            mm_double4& target = posq[atom];
            target.x = source[0];
            target.y = source[1];
            target.z = source[2];
        }
        for (int atom = numParticles; atom < cl.getPaddedNumAtoms(); atom++)
            posq[atom] = mm_double4(0.0, 0.0, 0.0, 0.0);
        cl.getPosq().upload(posq);
    }
    if (cl.getUseMixedPrecision()) {
        // The correction is the remainder left after rounding each coordinate to single precision.
        mm_float4* correction = (mm_float4*) cl.getPinnedBuffer();
        for (int atom = 0; atom < numParticles; ++atom) {
            const Vec3& source = positions[order[atom]];
            mm_float4& target = correction[atom];
            target.x = (cl_float) (source[0]-(cl_float)source[0]);
            target.y = (cl_float) (source[1]-(cl_float)source[1]);
            target.z = (cl_float) (source[2]-(cl_float)source[2]);
            target.w = 0;
        }
        for (int atom = numParticles; atom < cl.getPaddedNumAtoms(); atom++)
            correction[atom] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
        cl.getPosqCorrection().upload(correction);
    }

    // Reset the cell offset of every atom.

    for (auto& cellOffset : cl.getPosCellOffsets())
        cellOffset = mm_int4(0, 0, 0, 0);
    cl.reorderAtoms();
}

Peter Eastman's avatar
Peter Eastman committed
284
/**
 * Retrieve the particle velocities from the device.
 *
 * The velm array is downloaded into the pinned buffer as double precision values
 * in double or mixed precision mode, and as single precision values otherwise.
 * The x, y and z components of entry i are written to velocities[order[i]],
 * where order is cl.getAtomIndex().
 *
 * @param context     the context in which to execute this kernel
 * @param velocities  on exit, resized to the number of particles in the System
 *                    and filled with the velocities
 */
void OpenCLUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) {
    const vector<cl_int>& order = cl.getAtomIndex();
    int numParticles = context.getSystem().getNumParticles();
    velocities.resize(numParticles);
    if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
        mm_double4* velm = (mm_double4*) cl.getPinnedBuffer();
        cl.getVelm().download(velm);
        for (int i = 0; i < numParticles; ++i) {
            const mm_double4& vel = velm[i];
            velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
        }
    }
    else {
        mm_float4* velm = (mm_float4*) cl.getPinnedBuffer();
        cl.getVelm().download(velm);
        for (int i = 0; i < numParticles; ++i) {
            const mm_float4& vel = velm[i];
            velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
        }
    }
}

Peter Eastman's avatar
Peter Eastman committed
308
/**
 * Upload new particle velocities to the device.
 *
 * The existing velm array is downloaded first and only its x, y and z components
 * are overwritten, so the w component of each real atom keeps its downloaded value.
 * Padding atoms are set to zero.  Double precision storage is used in double or
 * mixed precision mode, single precision storage otherwise.
 *
 * @param context     the context in which to execute this kernel
 * @param velocities  the new velocities; velocities[order[i]] is stored at index i,
 *                    where order is cl.getAtomIndex()
 */
void OpenCLUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) {
    const vector<cl_int>& order = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    const bool storeAsDouble = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision());
    if (!storeAsDouble) {
        mm_float4* velm = (mm_float4*) cl.getPinnedBuffer();
        cl.getVelm().download(velm);
        for (int atom = 0; atom < numParticles; ++atom) {
            const Vec3& source = velocities[order[atom]];
            mm_float4& target = velm[atom];
            target.x = source[0];
            target.y = source[1];
            target.z = source[2];
        }
        for (int atom = numParticles; atom < cl.getPaddedNumAtoms(); atom++)
            velm[atom] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
        cl.getVelm().upload(velm);
        return;
    }
    mm_double4* velm = (mm_double4*) cl.getPinnedBuffer();
    cl.getVelm().download(velm);
    for (int atom = 0; atom < numParticles; ++atom) {
        const Vec3& source = velocities[order[atom]];
        mm_double4& target = velm[atom];
        target.x = source[0];
        target.y = source[1];
        target.z = source[2];
    }
    for (int atom = numParticles; atom < cl.getPaddedNumAtoms(); atom++)
        velm[atom] = mm_double4(0.0, 0.0, 0.0, 0.0);
    cl.getVelm().upload(velm);
}

Peter Eastman's avatar
Peter Eastman committed
341
/**
 * Retrieve the forces on the particles from the device.
 *
 * The force array is downloaded into the pinned buffer (double precision values
 * in double precision mode, single precision values otherwise) and the x, y and z
 * components of entry i are written to forces[order[i]], where order is
 * cl.getAtomIndex().
 *
 * @param context  the context in which to execute this kernel
 * @param forces   on exit, resized to the number of particles in the System
 *                 and filled with the forces
 */
void OpenCLUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) {
    const vector<cl_int>& order = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    forces.resize(numParticles);
    if (!cl.getUseDoublePrecision()) {
        mm_float4* deviceForces = (mm_float4*) cl.getPinnedBuffer();
        cl.getForce().download(deviceForces);
        for (int atom = 0; atom < numParticles; ++atom) {
            const mm_float4 value = deviceForces[atom];
            forces[order[atom]] = Vec3(value.x, value.y, value.z);
        }
        return;
    }
    mm_double4* deviceForces = (mm_double4*) cl.getPinnedBuffer();
    cl.getForce().download(deviceForces);
    for (int atom = 0; atom < numParticles; ++atom) {
        const mm_double4 value = deviceForces[atom];
        forces[order[atom]] = Vec3(value.x, value.y, value.z);
    }
}

363
/**
 * Retrieve the derivatives of the energy with respect to parameters.
 *
 * If the context lists no derivative names, derivs is left untouched.  Otherwise
 * derivs is replaced by a copy of the context's derivative workspace, and the
 * values accumulated in the device buffer are added to it.  The device buffer is
 * treated as consecutive records of numDerivs values each; the records are summed
 * element-wise before being added.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param derivs   on exit, maps parameter names to derivative values
 */
void OpenCLUpdateStateDataKernel::getEnergyParameterDerivatives(ContextImpl& context, map<string, double>& derivs) {
    const vector<string>& names = cl.getEnergyParamDerivNames();
    const int numDerivs = names.size();
    if (numDerivs == 0)
        return;
    derivs = cl.getEnergyParamDerivWorkspace();
    OpenCLArray& deviceDerivs = cl.getEnergyParamDerivBuffer();
    const bool storedAsDouble = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision());
    if (storedAsDouble) {
        vector<double> values;
        deviceDerivs.download(values);

        // Fold every later record into the first one.

        for (int record = numDerivs; record < deviceDerivs.getSize(); record += numDerivs)
            for (int k = 0; k < numDerivs; k++)
                values[k] += values[record+k];
        for (int k = 0; k < numDerivs; k++)
            derivs[names[k]] += values[k];
    }
    else {
        vector<float> values;
        deviceDerivs.download(values);

        // Fold every later record into the first one.

        for (int record = numDerivs; record < deviceDerivs.getSize(); record += numDerivs)
            for (int k = 0; k < numDerivs; k++)
                values[k] += values[record+k];
        for (int k = 0; k < numDerivs; k++)
            derivs[names[k]] += values[k];
    }
}

390
/**
 * Get the periodic box vectors, as reported by the OpenCL context.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param a        on exit, the first box vector
 * @param b        on exit, the second box vector
 * @param c        on exit, the third box vector
 */
void OpenCLUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
    // Delegates entirely to the OpenCL context.
    cl.getPeriodicBoxVectors(a, b, c);
}

394
/**
 * Set the periodic box vectors on every OpenCL context listed in the platform data.
 *
 * If any atom has a nonzero cell offset, the positions are read back before the
 * box is changed and written again afterwards, so that changing the box does not
 * change the particle positions.
 *
 * @param context  the context in which to execute this kernel
 * @param a        the first box vector
 * @param b        the second box vector
 * @param c        the third box vector
 */
void OpenCLUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) {
    vector<OpenCLContext*>& allContexts = cl.getPlatformData().contexts;

    // If any particles have been wrapped to the first periodic box, we need to unwrap them
    // to avoid changing their positions.

    bool anyWrapped = false;
    for (const auto& cellOffset : cl.getPosCellOffsets()) {
        if (cellOffset.x != 0 || cellOffset.y != 0 || cellOffset.z != 0) {
            anyWrapped = true;
            break;
        }
    }
    vector<Vec3> savedPositions;
    if (anyWrapped)
        getPositions(context, savedPositions);

    // Update the vectors.

    for (auto current = allContexts.begin(); current != allContexts.end(); ++current)
        (*current)->setPeriodicBoxVectors(a, b, c);
    if (!savedPositions.empty())
        setPositions(context, savedPositions);
}

Peter Eastman's avatar
Peter Eastman committed
416
/**
 * Write a checkpoint of the current state to a stream.
 *
 * The data is written as raw bytes in this order: format version (3), precision
 * code (2 = double, 1 = mixed, 0 = single), time, step count, steps since the
 * last reorder, the posq array, the posq correction array (mixed precision only),
 * the velm array, the atom index array, the cell offsets, and the three periodic
 * box vectors.  The integration utilities and SimTKOpenMMUtilities then append
 * their own data.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param stream   the stream to write the checkpoint to
 */
void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
    // Header.

    int version = 3;
    int precision = (cl.getUseDoublePrecision() ? 2 : cl.getUseMixedPrecision() ? 1 : 0);
    double time = cl.getTime();
    int stepCount = cl.getStepCount();
    int stepsSinceReorder = cl.getStepsSinceReorder();
    stream.write((char*) &version, sizeof(int));
    stream.write((char*) &precision, sizeof(int));
    stream.write((char*) &time, sizeof(double));
    stream.write((char*) &stepCount, sizeof(int));
    stream.write((char*) &stepsSinceReorder, sizeof(int));

    // Device arrays, staged through the pinned buffer.

    char* staging = (char*) cl.getPinnedBuffer();
    cl.getPosq().download(staging);
    stream.write(staging, cl.getPosq().getSize()*cl.getPosq().getElementSize());
    if (cl.getUseMixedPrecision()) {
        cl.getPosqCorrection().download(staging);
        stream.write(staging, cl.getPosqCorrection().getSize()*cl.getPosqCorrection().getElementSize());
    }
    cl.getVelm().download(staging);
    stream.write(staging, cl.getVelm().getSize()*cl.getVelm().getElementSize());

    // Host side arrays and the periodic box.

    stream.write((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().size());
    stream.write((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
    Vec3 box[3];
    cl.getPeriodicBoxVectors(box[0], box[1], box[2]);
    stream.write((char*) box, 3*sizeof(Vec3));

    // Let the other components append their own state.

    cl.getIntegrationUtilities().createCheckpoint(stream);
    SimTKOpenMMUtilities::createCheckpoint(stream);
}

/**
 * Restore the state from a checkpoint previously written by createCheckpoint().
 *
 * The data is read in the same order it was written: version, precision code,
 * time, step count, steps since the last reorder, posq, posq correction (mixed
 * precision only), velm, atom indices, cell offsets and periodic box vectors,
 * followed by the data of the integration utilities and SimTKOpenMMUtilities.
 * The time, step counts and box vectors are applied to every context listed in
 * the platform data, and all reorder listeners are executed at the end.
 *
 * Every read performed directly by this function is checked, so a truncated or
 * unreadable stream raises an exception instead of restoring indeterminate data.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param stream   the stream to read the checkpoint from
 * @throws OpenMMException if the version or precision does not match, or if the
 *         stream fails while being read by this function
 */
void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
    auto readBytes = [&stream] (char* destination, size_t numBytes) {
        stream.read(destination, numBytes);
        if (!stream)
            throw OpenMMException("Checkpoint data is truncated or could not be read");
    };

    // Validate the header before touching any state.

    int version = 0;
    readBytes((char*) &version, sizeof(int));
    if (version != 3)
        throw OpenMMException("Checkpoint was created with a different version of OpenMM");
    int precision = 0;
    readBytes((char*) &precision, sizeof(int));
    int expectedPrecision = (cl.getUseDoublePrecision() ? 2 : cl.getUseMixedPrecision() ? 1 : 0);
    if (precision != expectedPrecision)
        throw OpenMMException("Checkpoint was created with a different numeric precision");
    double time = 0.0;
    readBytes((char*) &time, sizeof(double));
    int stepCount = 0, stepsSinceReorder = 0;
    readBytes((char*) &stepCount, sizeof(int));
    readBytes((char*) &stepsSinceReorder, sizeof(int));
    vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
    for (auto ctx : contexts) {
        ctx->setTime(time);
        ctx->setStepCount(stepCount);
        ctx->setStepsSinceReorder(stepsSinceReorder);
    }

    // Device arrays are staged through the pinned buffer; each is uploaded only
    // after it has been read completely.

    char* buffer = (char*) cl.getPinnedBuffer();
    readBytes(buffer, cl.getPosq().getSize()*cl.getPosq().getElementSize());
    cl.getPosq().upload(buffer);
    if (cl.getUseMixedPrecision()) {
        readBytes(buffer, cl.getPosqCorrection().getSize()*cl.getPosqCorrection().getElementSize());
        cl.getPosqCorrection().upload(buffer);
    }
    readBytes(buffer, cl.getVelm().getSize()*cl.getVelm().getElementSize());
    cl.getVelm().upload(buffer);

    // The atom indices and cell offsets are read directly into the context's host arrays.

    readBytes((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().size());
    cl.getAtomIndexArray().upload(cl.getAtomIndex());
    readBytes((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
    Vec3 boxVectors[3];
    readBytes((char*) &boxVectors, 3*sizeof(Vec3));
    for (auto ctx : contexts)
        ctx->setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
    cl.getIntegrationUtilities().loadCheckpoint(stream);
    SimTKOpenMMUtilities::loadCheckpoint(stream);
    for (auto listener : cl.getReorderListeners())
        listener->execute();
}

488
489
490
491
/**
 * Initialize the kernel.  This implementation has nothing to set up; the OpenCL
 * kernel is created on the first call to apply().
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLApplyConstraintsKernel::initialize(const System& system) {
    // Intentionally empty.
}

/**
 * Update particle positions so that all distance constraints are satisfied.
 *
 * On the first call the "applyPositionDeltas" kernel is compiled and its arguments
 * are bound.  Each call then clears the position delta buffer, lets the integration
 * utilities apply the constraints, runs the kernel over all atoms, and recomputes
 * the virtual sites.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param tol      the tolerance passed to the constraint algorithm
 */
void OpenCLApplyConstraintsKernel::apply(ContextImpl& context, double tol) {
    if (!hasInitializedKernel) {
        // Build the kernel lazily, the first time it is needed.
        hasInitializedKernel = true;
        map<string, string> defines;
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        cl::Program program = cl.createProgram(OpenCLKernelSources::constraints, defines);
        applyDeltasKernel = cl::Kernel(program, "applyPositionDeltas");
        applyDeltasKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, applyDeltasKernel, 1);
        applyDeltasKernel.setArg<cl::Buffer>(2, cl.getIntegrationUtilities().getPosDelta().getDeviceBuffer());
    }
    OpenCLIntegrationUtilities& utilities = cl.getIntegrationUtilities();
    cl.clearBuffer(utilities.getPosDelta());
    utilities.applyConstraints(tol);
    const int numAtoms = cl.getNumAtoms();
    cl.executeKernel(applyDeltasKernel, numAtoms);
    utilities.computeVirtualSites();
}

509
510
511
512
/**
 * Apply the constraints to the particle velocities by delegating to the
 * integration utilities.
 *
 * @param context  the context in which to execute this kernel (unused)
 * @param tol      the tolerance passed to the constraint algorithm
 */
void OpenCLApplyConstraintsKernel::applyToVelocities(ContextImpl& context, double tol) {
    OpenCLIntegrationUtilities& utilities = cl.getIntegrationUtilities();
    utilities.applyVelocityConstraints(tol);
}

513
514
515
516
517
518
519
/**
 * Initialize the kernel.  This implementation has nothing to set up.
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLVirtualSitesKernel::initialize(const System& system) {
    // Intentionally empty.
}

/**
 * Compute the virtual site positions by delegating to the integration utilities.
 *
 * @param context  the context in which to execute this kernel (unused)
 */
void OpenCLVirtualSitesKernel::computePositions(ContextImpl& context) {
    OpenCLIntegrationUtilities& utilities = cl.getIntegrationUtilities();
    utilities.computeVirtualSites();
}

520
/**
 * Exposes the bonds of a HarmonicBondForce as particle groups: each bond is one
 * group made of its two particles, and two groups are considered identical when
 * their length and force constant are equal.
 */
class OpenCLCalcHarmonicBondForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const HarmonicBondForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of bonds in the force. */
    int getNumParticleGroups() {
        return force.getNumBonds();
    }
    /** Store the two particles of bond `index` in `particles`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int first, second;
        double length, k;
        force.getBondParameters(index, first, second, length, k);
        particles.resize(2);
        particles[0] = first;
        particles[1] = second;
    }
    /** Two bonds are identical when both their length and their k match exactly. */
    bool areGroupsIdentical(int group1, int group2) {
        int ignoredA, ignoredB;
        double lengthA, kA, lengthB, kB;
        force.getBondParameters(group1, ignoredA, ignoredB, lengthA, kA);
        force.getBondParameters(group2, ignoredA, ignoredB, lengthB, kB);
        if (lengthA != lengthB)
            return false;
        return (kA == kB);
    }
private:
    const HarmonicBondForce& force;
};

void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const HarmonicBondForce& force) {
547
548
549
550
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    numBonds = endIndex-startIndex;
551
552
    if (numBonds == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
553
    vector<vector<int> > atoms(numBonds, vector<int>(2));
peastman's avatar
peastman committed
554
    params.initialize<mm_float2>(cl, numBonds, "bondParams");
555
556
557
    vector<mm_float2> paramVector(numBonds);
    for (int i = 0; i < numBonds; i++) {
        double length, k;
Peter Eastman's avatar
Peter Eastman committed
558
        force.getBondParameters(startIndex+i, atoms[i][0], atoms[i][1], length, k);
559
        paramVector[i] = mm_float2((cl_float) length, (cl_float) k);
560
    }
peastman's avatar
peastman committed
561
    params.upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
562
    map<string, string> replacements;
563
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
564
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::harmonicBondForce;
peastman's avatar
peastman committed
565
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params.getDeviceBuffer(), "float2");
566
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
567
568
    info = new ForceInfo(force);
    cl.addForce(info);
569
570
}

571
/**
 * Execute the kernel.  Nothing is computed here and 0 is always returned; the
 * interaction itself was registered with the bonded utilities in initialize().
 *
 * @param context        the context in which to execute this kernel (unused)
 * @param includeForces  unused
 * @param includeEnergy  unused
 * @return always 0.0
 */
double OpenCLCalcHarmonicBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const double energy = 0.0;
    return energy;
}
574

575
576
577
578
579
580
/**
 * Copy changed bond parameters from a HarmonicBondForce to the device.
 *
 * The range of bonds handled by this context is recomputed the same way as in
 * initialize() and must contain the same number of bonds as before.  If this
 * context handles no bonds the function returns without doing anything else.
 *
 * @param context  the context to copy parameters to (unused)
 * @param force    the HarmonicBondForce to copy the parameters from
 * @throws OpenMMException if the number of bonds handled by this context has changed
 */
void OpenCLCalcHarmonicBondForceKernel::copyParametersToContext(ContextImpl& context, const HarmonicBondForce& force) {
    const int contextCount = cl.getPlatformData().contexts.size();
    const int startIndex = cl.getContextIndex()*force.getNumBonds()/contextCount;
    const int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/contextCount;
    const int expectedBonds = endIndex-startIndex;
    if (numBonds != expectedBonds)
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;

    // Record the per-bond parameters.

    vector<mm_float2> bondParams(numBonds);
    for (int bond = 0; bond < numBonds; bond++) {
        int ignoredA, ignoredB;
        double length, k;
        force.getBondParameters(startIndex+bond, ignoredA, ignoredB, length, k);
        bondParams[bond] = mm_float2((cl_float) length, (cl_float) k);
    }
    params.upload(bondParams);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

600
/**
 * Exposes the bonds of a CustomBondForce as particle groups: each bond is one
 * group made of its two particles, and two groups are considered identical when
 * all of their per-bond parameters are equal.
 */
class OpenCLCalcCustomBondForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CustomBondForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of bonds in the force. */
    int getNumParticleGroups() {
        return force.getNumBonds();
    }
    /** Store the two particles of bond `index` in `particles`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int first, second;
        vector<double> bondParameters;
        force.getBondParameters(index, first, second, bondParameters);
        particles.resize(2);
        particles[0] = first;
        particles[1] = second;
    }
    /** Two bonds are identical when every parameter of the first equals the matching parameter of the second. */
    bool areGroupsIdentical(int group1, int group2) {
        int ignoredA, ignoredB;
        vector<double> paramsA, paramsB;
        force.getBondParameters(group1, ignoredA, ignoredB, paramsA);
        force.getBondParameters(group2, ignoredA, ignoredB, paramsB);
        const int count = (int) paramsA.size();
        int k = 0;
        while (k < count) {
            if (paramsA[k] != paramsB[k])
                return false;
            k++;
        }
        return true;
    }
private:
    const CustomBondForce& force;
};

/**
 * Release the parameter set owned by this kernel, if one was created.
 */
OpenCLCalcCustomBondForceKernel::~OpenCLCalcCustomBondForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is needed.
    delete params;
}

void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const CustomBondForce& force) {
635
636
637
638
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    numBonds = endIndex-startIndex;
639
640
    if (numBonds == 0)
        return;
641
    vector<vector<int> > atoms(numBonds, vector<int>(2));
642
643
    params = new OpenCLParameterSet(cl, force.getNumPerBondParameters(), numBonds, "customBondParams");
    vector<vector<cl_float> > paramVector(numBonds);
644
645
    for (int i = 0; i < numBonds; i++) {
        vector<double> parameters;
646
        force.getBondParameters(startIndex+i, atoms[i][0], atoms[i][1], parameters);
647
        paramVector[i].resize(parameters.size());
648
        for (int j = 0; j < (int) parameters.size(); j++)
649
            paramVector[i][j] = (cl_float) parameters[j];
650
    }
651
    params->setParameterValues(paramVector);
652
653
    info = new ForceInfo(force);
    cl.addForce(info);
654
655
656
657
658
659
660
661
662
663
664
665
666

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
667
    expressions["real dEdR = "] = forceExpression;
668
669
670
671
672
673
674

    // Create the kernels.

    map<string, string> variables;
    variables["r"] = "r";
    for (int i = 0; i < force.getNumPerBondParameters(); i++) {
        const string& name = force.getPerBondParameterName(i);
675
        variables[name] = "bondParams"+params->getParameterSuffix(i);
676
    }
677
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
678
679
680
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customBondGlobals", CL_MEM_READ_ONLY);
        globals.upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals.getDeviceBuffer(), "float");
681
682
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
683
            string value = argName+"["+cl.intToString(i)+"]";
684
685
            variables[name] = value;
        }
686
    }
687
688
689
690
691
692
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string paramName = force.getEnergyParameterDerivativeName(i);
        string derivVariable = cl.getBondedUtilities().addEnergyParameterDerivative(paramName);
        Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
        expressions[derivVariable+" += "] = derivExpression;
    }
693
    stringstream compute;
694
695
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
696
697
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
698
    }
peastman's avatar
peastman committed
699
700
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
701
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
702
    map<string, string> replacements;
703
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
704
    replacements["COMPUTE_FORCE"] = compute.str();
705
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
706
707
}

708
/**
 * Bring the device copy of the global parameters up to date.
 *
 * When the globals buffer has been initialized, each global parameter is re-read
 * from the context and the buffer is uploaded again if any value differs from
 * the cached one.  This method always returns 0.0.
 */
double OpenCLCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (!globals.isInitialized())
        return 0.0;
    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int g = 0; g < numGlobals; g++) {
        const cl_float current = static_cast<cl_float>(context.getParameter(globalParamNames[g]));
        needsUpload = needsUpload || (current != globalParamValues[g]);
        globalParamValues[g] = current;
    }
    if (needsUpload)
        globals.upload(globalParamValues);
    return 0.0;
}

723
724
725
726
727
728
/**
 * Copy the current per-bond parameters of a CustomBondForce into the parameter set.
 *
 * @param context  the context to update (not referenced by this method)
 * @param force    the CustomBondForce to read parameters from
 * @throws OpenMMException if this context's share of the bonds has a different
 *         size than the one recorded in numBonds
 */
void OpenCLCalcCustomBondForceKernel::copyParametersToContext(ContextImpl& context, const CustomBondForce& force) {
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstBond = cl.getContextIndex()*force.getNumBonds()/contextCount;
    const int lastBond = (cl.getContextIndex()+1)*force.getNumBonds()/contextCount;
    if (lastBond-firstBond != numBonds)
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;

    // Record the per-bond parameters.

    vector<vector<cl_float> > paramVector(numBonds);
    vector<double> bondValues;
    for (int bond = 0; bond < numBonds; bond++) {
        int unusedAtom1, unusedAtom2;
        force.getBondParameters(firstBond+bond, unusedAtom1, unusedAtom2, bondValues);
        vector<cl_float>& row = paramVector[bond];
        row.resize(bondValues.size());
        for (int p = 0; p < (int) bondValues.size(); p++)
            row[p] = static_cast<cl_float>(bondValues[p]);
    }
    params->setParameterValues(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

750
/**
 * Exposes the angles of a HarmonicAngleForce as particle groups: one group of
 * three particles per angle.
 */
class OpenCLCalcHarmonicAngleForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const HarmonicAngleForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of angles in the force. */
    int getNumParticleGroups() {
        return force.getNumAngles();
    }
    /** Fill `particles` with the three particle indices of angle `index`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int members[3];
        double ignoredAngle, ignoredK;
        force.getAngleParameters(index, members[0], members[1], members[2], ignoredAngle, ignoredK);
        particles.assign(members, members+3);
    }
    /** Two angles are identical when both the angle value and k match. */
    bool areGroupsIdentical(int group1, int group2) {
        int atomA, atomB, atomC;
        double angleFirst, kFirst, angleSecond, kSecond;
        force.getAngleParameters(group1, atomA, atomB, atomC, angleFirst, kFirst);
        force.getAngleParameters(group2, atomA, atomB, atomC, angleSecond, kSecond);
        if (angleFirst != angleSecond)
            return false;
        return kFirst == kSecond;
    }
private:
    const HarmonicAngleForce& force;
};

void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const HarmonicAngleForce& force) {
778
779
780
781
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumAngles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumAngles()/numContexts;
    numAngles = endIndex-startIndex;
782
783
    if (numAngles == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
784
    vector<vector<int> > atoms(numAngles, vector<int>(3));
peastman's avatar
peastman committed
785
    params.initialize<mm_float2>(cl, numAngles, "angleParams");
786
787
788
    vector<mm_float2> paramVector(numAngles);
    for (int i = 0; i < numAngles; i++) {
        double angle, k;
Peter Eastman's avatar
Peter Eastman committed
789
        force.getAngleParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], angle, k);
790
        paramVector[i] = mm_float2((cl_float) angle, (cl_float) k);
791
792

    }
peastman's avatar
peastman committed
793
    params.upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
794
    map<string, string> replacements;
795
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
796
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::harmonicAngleForce;
peastman's avatar
peastman committed
797
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params.getDeviceBuffer(), "float2");
798
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
799
800
    info = new ForceInfo(force);
    cl.addForce(info);
801
802
}

803
/**
 * Perform no work and report an energy contribution of 0.0.
 *
 * NOTE(review): presumably the actual computation is carried out by the bonded
 * utilities that initialize() registered the interaction with -- confirm.
 */
double OpenCLCalcHarmonicAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

807
808
809
810
811
812
/**
 * Copy the current per-angle parameters of a HarmonicAngleForce to the device.
 *
 * @param context  the context to update (not referenced by this method)
 * @param force    the HarmonicAngleForce to read parameters from
 * @throws OpenMMException if this context's share of the angles has a different
 *         size than the one recorded in numAngles
 */
void OpenCLCalcHarmonicAngleForceKernel::copyParametersToContext(ContextImpl& context, const HarmonicAngleForce& force) {
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstAngle = cl.getContextIndex()*force.getNumAngles()/contextCount;
    const int lastAngle = (cl.getContextIndex()+1)*force.getNumAngles()/contextCount;
    if (lastAngle-firstAngle != numAngles)
        throw OpenMMException("updateParametersInContext: The number of angles has changed");
    if (numAngles == 0)
        return;

    // Record the per-angle parameters.

    vector<mm_float2> paramVector(numAngles);
    for (int a = 0; a < numAngles; a++) {
        int unusedAtom1, unusedAtom2, unusedAtom3;
        double angle, k;
        force.getAngleParameters(firstAngle+a, unusedAtom1, unusedAtom2, unusedAtom3, angle, k);
        paramVector[a] = mm_float2(static_cast<cl_float>(angle), static_cast<cl_float>(k));
    }
    params.upload(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

832
/**
 * Exposes the angles of a CustomAngleForce as particle groups: one group of
 * three particles per angle.
 */
class OpenCLCalcCustomAngleForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CustomAngleForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of angles in the force. */
    int getNumParticleGroups() {
        return force.getNumAngles();
    }
    /** Fill `particles` with the three particle indices of angle `index`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int members[3];
        vector<double> ignoredParameters;
        force.getAngleParameters(index, members[0], members[1], members[2], ignoredParameters);
        particles.assign(members, members+3);
    }
    /** Two angles are identical when every per-angle parameter value matches. */
    bool areGroupsIdentical(int group1, int group2) {
        int atomA, atomB, atomC;
        vector<double> first, second;
        force.getAngleParameters(group1, atomA, atomB, atomC, first);
        force.getAngleParameters(group2, atomA, atomB, atomC, second);
        const int count = (int) first.size();
        for (int p = 0; p < count; p++) {
            if (first[p] != second[p])
                return false;
        }
        return true;
    }
private:
    const CustomAngleForce& force;
};

/**
 * Release the per-angle parameter set owned by this kernel.
 */
OpenCLCalcCustomAngleForceKernel::~OpenCLCalcCustomAngleForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete params;
}

void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const CustomAngleForce& force) {
868
869
870
871
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumAngles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumAngles()/numContexts;
    numAngles = endIndex-startIndex;
872
873
    if (numAngles == 0)
        return;
874
    vector<vector<int> > atoms(numAngles, vector<int>(3));
875
876
877
878
    params = new OpenCLParameterSet(cl, force.getNumPerAngleParameters(), numAngles, "customAngleParams");
    vector<vector<cl_float> > paramVector(numAngles);
    for (int i = 0; i < numAngles; i++) {
        vector<double> parameters;
879
        force.getAngleParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], parameters);
880
881
882
883
884
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
885
886
    info = new ForceInfo(force);
    cl.addForce(info);
887
888
889
890
891
892
893
894
895
896
897
898
899

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
900
    expressions["real dEdAngle = "] = forceExpression;
901
902
903
904
905
906
907
908
909

    // Create the kernels.

    map<string, string> variables;
    variables["theta"] = "theta";
    for (int i = 0; i < force.getNumPerAngleParameters(); i++) {
        const string& name = force.getPerAngleParameterName(i);
        variables[name] = "angleParams"+params->getParameterSuffix(i);
    }
910
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
911
912
913
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customAngleGlobals", CL_MEM_READ_ONLY);
        globals.upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals.getDeviceBuffer(), "float");
914
915
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
916
            string value = argName+"["+cl.intToString(i)+"]";
917
918
            variables[name] = value;
        }
919
    }
920
921
922
923
924
925
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string paramName = force.getEnergyParameterDerivativeName(i);
        string derivVariable = cl.getBondedUtilities().addEnergyParameterDerivative(paramName);
        Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
        expressions[derivVariable+" += "] = derivExpression;
    }
926
927
928
    stringstream compute;
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
929
930
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
931
    }
peastman's avatar
peastman committed
932
933
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
934
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
935
    map<string, string> replacements;
936
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
937
    replacements["COMPUTE_FORCE"] = compute.str();
938
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
939
940
}

941
/**
 * Bring the device copy of the global parameters up to date.
 *
 * When the globals buffer has been initialized, each global parameter is re-read
 * from the context and the buffer is uploaded again if any value differs from
 * the cached one.  This method always returns 0.0.
 */
double OpenCLCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (!globals.isInitialized())
        return 0.0;
    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int g = 0; g < numGlobals; g++) {
        const cl_float current = static_cast<cl_float>(context.getParameter(globalParamNames[g]));
        needsUpload = needsUpload || (current != globalParamValues[g]);
        globalParamValues[g] = current;
    }
    if (needsUpload)
        globals.upload(globalParamValues);
    return 0.0;
}

956
957
958
959
960
961
/**
 * Copy the current per-angle parameters of a CustomAngleForce into the parameter set.
 *
 * @param context  the context to update (not referenced by this method)
 * @param force    the CustomAngleForce to read parameters from
 * @throws OpenMMException if this context's share of the angles has a different
 *         size than the one recorded in numAngles
 */
void OpenCLCalcCustomAngleForceKernel::copyParametersToContext(ContextImpl& context, const CustomAngleForce& force) {
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstAngle = cl.getContextIndex()*force.getNumAngles()/contextCount;
    const int lastAngle = (cl.getContextIndex()+1)*force.getNumAngles()/contextCount;
    if (lastAngle-firstAngle != numAngles)
        throw OpenMMException("updateParametersInContext: The number of angles has changed");
    if (numAngles == 0)
        return;

    // Record the per-angle parameters.

    vector<vector<cl_float> > paramVector(numAngles);
    vector<double> angleValues;
    for (int a = 0; a < numAngles; a++) {
        int unusedAtom1, unusedAtom2, unusedAtom3;
        force.getAngleParameters(firstAngle+a, unusedAtom1, unusedAtom2, unusedAtom3, angleValues);
        vector<cl_float>& row = paramVector[a];
        row.resize(angleValues.size());
        for (int p = 0; p < (int) angleValues.size(); p++)
            row[p] = static_cast<cl_float>(angleValues[p]);
    }
    params->setParameterValues(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

983
/**
 * Exposes the torsions of a PeriodicTorsionForce as particle groups: one group
 * of four particles per torsion.
 */
class OpenCLCalcPeriodicTorsionForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const PeriodicTorsionForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of torsions in the force. */
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
    /** Fill `particles` with the four particle indices of torsion `index`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int members[4];
        int ignoredPeriodicity;
        double ignoredPhase, ignoredK;
        force.getTorsionParameters(index, members[0], members[1], members[2], members[3], ignoredPeriodicity, ignoredPhase, ignoredK);
        particles.assign(members, members+4);
    }
    /** Two torsions are identical when periodicity, phase and k all match. */
    bool areGroupsIdentical(int group1, int group2) {
        int atomA, atomB, atomC, atomD;
        int periodicityFirst, periodicitySecond;
        double phaseFirst, kFirst, phaseSecond, kSecond;
        force.getTorsionParameters(group1, atomA, atomB, atomC, atomD, periodicityFirst, phaseFirst, kFirst);
        force.getTorsionParameters(group2, atomA, atomB, atomC, atomD, periodicitySecond, phaseSecond, kSecond);
        if (periodicityFirst != periodicitySecond)
            return false;
        if (phaseFirst != phaseSecond)
            return false;
        return kFirst == kSecond;
    }
private:
    const PeriodicTorsionForce& force;
};

void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, const PeriodicTorsionForce& force) {
1012
1013
1014
1015
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1016
1017
    if (numTorsions == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
1018
    vector<vector<int> > atoms(numTorsions, vector<int>(4));
peastman's avatar
peastman committed
1019
    params.initialize<mm_float4>(cl, numTorsions, "periodicTorsionParams");
1020
1021
    vector<mm_float4> paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
Peter Eastman's avatar
Peter Eastman committed
1022
        int periodicity;
1023
        double phase, k;
Peter Eastman's avatar
Peter Eastman committed
1024
        force.getTorsionParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], periodicity, phase, k);
1025
        paramVector[i] = mm_float4((cl_float) k, (cl_float) phase, (cl_float) periodicity, 0.0f);
1026
    }
peastman's avatar
peastman committed
1027
    params.upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
1028
    map<string, string> replacements;
1029
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
1030
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::periodicTorsionForce;
peastman's avatar
peastman committed
1031
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params.getDeviceBuffer(), "float4");
1032
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
1033
1034
    info = new ForceInfo(force);
    cl.addForce(info);
1035
1036
}

1037
/**
 * Perform no work and report an energy contribution of 0.0.
 *
 * NOTE(review): presumably the actual computation is carried out by the bonded
 * utilities that initialize() registered the interaction with -- confirm.
 */
double OpenCLCalcPeriodicTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

1041
1042
1043
1044
1045
1046
/**
 * Copy the current per-torsion parameters of a PeriodicTorsionForce to the device.
 *
 * @param context  the context to update (not referenced by this method)
 * @param force    the PeriodicTorsionForce to read parameters from
 * @throws OpenMMException if this context's share of the torsions has a
 *         different size than the one recorded in numTorsions
 */
void OpenCLCalcPeriodicTorsionForceKernel::copyParametersToContext(ContextImpl& context, const PeriodicTorsionForce& force) {
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstTorsion = cl.getContextIndex()*force.getNumTorsions()/contextCount;
    const int lastTorsion = (cl.getContextIndex()+1)*force.getNumTorsions()/contextCount;
    if (lastTorsion-firstTorsion != numTorsions)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");
    if (numTorsions == 0)
        return;

    // Record the per-torsion parameters.

    vector<mm_float4> paramVector(numTorsions);
    for (int t = 0; t < numTorsions; t++) {
        int unusedAtom1, unusedAtom2, unusedAtom3, unusedAtom4;
        int periodicity;
        double phase, k;
        force.getTorsionParameters(firstTorsion+t, unusedAtom1, unusedAtom2, unusedAtom3, unusedAtom4, periodicity, phase, k);
        paramVector[t] = mm_float4(static_cast<cl_float>(k), static_cast<cl_float>(phase), static_cast<cl_float>(periodicity), 0.0f);
    }
    params.upload(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

1066
/**
 * Exposes the torsions of an RBTorsionForce as particle groups: one group of
 * four particles per torsion.
 */
class OpenCLCalcRBTorsionForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const RBTorsionForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of torsions in the force. */
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
    /** Fill `particles` with the four particle indices of torsion `index`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int members[4];
        double ignored[6];
        force.getTorsionParameters(index, members[0], members[1], members[2], members[3],
                ignored[0], ignored[1], ignored[2], ignored[3], ignored[4], ignored[5]);
        particles.assign(members, members+4);
    }
    /** Two torsions are identical when all six coefficients c0..c5 match. */
    bool areGroupsIdentical(int group1, int group2) {
        int atomA, atomB, atomC, atomD;
        double first[6], second[6];
        force.getTorsionParameters(group1, atomA, atomB, atomC, atomD,
                first[0], first[1], first[2], first[3], first[4], first[5]);
        force.getTorsionParameters(group2, atomA, atomB, atomC, atomD,
                second[0], second[1], second[2], second[3], second[4], second[5]);
        for (int c = 0; c < 6; c++) {
            if (first[c] != second[c])
                return false;
        }
        return true;
    }
private:
    const RBTorsionForce& force;
};

void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTorsionForce& force) {
1095
1096
1097
1098
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1099
1100
    if (numTorsions == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
1101
    vector<vector<int> > atoms(numTorsions, vector<int>(4));
peastman's avatar
peastman committed
1102
    params.initialize<mm_float8>(cl, numTorsions, "rbTorsionParams");
1103
1104
1105
    vector<mm_float8> paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        double c0, c1, c2, c3, c4, c5;
Peter Eastman's avatar
Peter Eastman committed
1106
        force.getTorsionParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], c0, c1, c2, c3, c4, c5);
1107
        paramVector[i] = mm_float8((cl_float) c0, (cl_float) c1, (cl_float) c2, (cl_float) c3, (cl_float) c4, (cl_float) c5, 0.0f, 0.0f);
1108
1109

    }
peastman's avatar
peastman committed
1110
    params.upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
1111
    map<string, string> replacements;
1112
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
1113
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::rbTorsionForce;
peastman's avatar
peastman committed
1114
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params.getDeviceBuffer(), "float8");
1115
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
1116
1117
    info = new ForceInfo(force);
    cl.addForce(info);
1118
1119
}

1120
/**
 * Perform no work and report an energy contribution of 0.0.
 *
 * NOTE(review): presumably the actual computation is carried out by the bonded
 * utilities that initialize() registered the interaction with -- confirm.
 */
double OpenCLCalcRBTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

1124
1125
1126
1127
1128
1129
/**
 * Copy the current per-torsion coefficients of an RBTorsionForce to the device.
 *
 * @param context  the context to update (not referenced by this method)
 * @param force    the RBTorsionForce to read parameters from
 * @throws OpenMMException if this context's share of the torsions has a
 *         different size than the one recorded in numTorsions
 */
void OpenCLCalcRBTorsionForceKernel::copyParametersToContext(ContextImpl& context, const RBTorsionForce& force) {
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstTorsion = cl.getContextIndex()*force.getNumTorsions()/contextCount;
    const int lastTorsion = (cl.getContextIndex()+1)*force.getNumTorsions()/contextCount;
    if (lastTorsion-firstTorsion != numTorsions)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");
    if (numTorsions == 0)
        return;

    // Record the per-torsion parameters.

    vector<mm_float8> paramVector(numTorsions);
    for (int t = 0; t < numTorsions; t++) {
        int unusedAtom1, unusedAtom2, unusedAtom3, unusedAtom4;
        double coeff[6];
        force.getTorsionParameters(firstTorsion+t, unusedAtom1, unusedAtom2, unusedAtom3, unusedAtom4,
                coeff[0], coeff[1], coeff[2], coeff[3], coeff[4], coeff[5]);
        paramVector[t] = mm_float8(static_cast<cl_float>(coeff[0]), static_cast<cl_float>(coeff[1]),
                static_cast<cl_float>(coeff[2]), static_cast<cl_float>(coeff[3]),
                static_cast<cl_float>(coeff[4]), static_cast<cl_float>(coeff[5]), 0.0f, 0.0f);
    }
    params.upload(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

1149
/**
 * Exposes the torsion pairs of a CMAPTorsionForce as particle groups: one group
 * of eight particles (four per torsion) per entry.
 */
class OpenCLCalcCMAPTorsionForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CMAPTorsionForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** The number of groups equals the number of CMAP torsions in the force. */
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
    /** Fill `particles` with the eight particle indices of CMAP torsion `index`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int mapIndex;
        int members[8];
        force.getTorsionParameters(index, mapIndex, members[0], members[1], members[2], members[3],
                members[4], members[5], members[6], members[7]);
        particles.assign(members, members+8);
    }
    /** Two CMAP torsions are identical when they refer to the same map. */
    bool areGroupsIdentical(int group1, int group2) {
        int mapFirst, mapSecond;
        int members[8];
        force.getTorsionParameters(group1, mapFirst, members[0], members[1], members[2], members[3],
                members[4], members[5], members[6], members[7]);
        force.getTorsionParameters(group2, mapSecond, members[0], members[1], members[2], members[3],
                members[4], members[5], members[6], members[7]);
        return mapFirst == mapSecond;
    }
private:
    const CMAPTorsionForce& force;
};

void OpenCLCalcCMAPTorsionForceKernel::initialize(const System& system, const CMAPTorsionForce& force) {
1180
1181
1182
1183
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1184
1185
1186
1187
    if (numTorsions == 0)
        return;
    int numMaps = force.getNumMaps();
    vector<mm_float4> coeffVec;
1188
    mapPositionsVec.resize(numMaps);
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
    vector<double> energy;
    vector<vector<double> > c;
    int currentPosition = 0;
    for (int i = 0; i < numMaps; i++) {
        int size;
        force.getMapParameters(i, size, energy);
        CMAPTorsionForceImpl::calcMapDerivatives(size, energy, c);
        mapPositionsVec[i] = mm_int2(currentPosition, size);
        currentPosition += 4*size*size;
        for (int j = 0; j < size*size; j++) {
1199
1200
1201
1202
            coeffVec.push_back(mm_float4((float) c[j][0], (float) c[j][1], (float) c[j][2], (float) c[j][3]));
            coeffVec.push_back(mm_float4((float) c[j][4], (float) c[j][5], (float) c[j][6], (float) c[j][7]));
            coeffVec.push_back(mm_float4((float) c[j][8], (float) c[j][9], (float) c[j][10], (float) c[j][11]));
            coeffVec.push_back(mm_float4((float) c[j][12], (float) c[j][13], (float) c[j][14], (float) c[j][15]));
1203
1204
        }
    }
1205
    vector<vector<int> > atoms(numTorsions, vector<int>(8));
1206
    vector<cl_int> torsionMapsVec(numTorsions);
1207
1208
    for (int i = 0; i < numTorsions; i++)
        force.getTorsionParameters(startIndex+i, torsionMapsVec[i], atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], atoms[i][4], atoms[i][5], atoms[i][6], atoms[i][7]);
peastman's avatar
peastman committed
1209
1210
1211
1212
1213
1214
    coefficients.initialize<mm_float4>(cl, coeffVec.size(), "cmapTorsionCoefficients");
    mapPositions.initialize<mm_int2>(cl, numMaps, "cmapTorsionMapPositions");
    torsionMaps.initialize<cl_int>(cl, numTorsions, "cmapTorsionMaps");
    coefficients.upload(coeffVec);
    mapPositions.upload(mapPositionsVec);
    torsionMaps.upload(torsionMapsVec);
1215
    map<string, string> replacements;
1216
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
peastman's avatar
peastman committed
1217
1218
1219
    replacements["COEFF"] = cl.getBondedUtilities().addArgument(coefficients.getDeviceBuffer(), "float4");
    replacements["MAP_POS"] = cl.getBondedUtilities().addArgument(mapPositions.getDeviceBuffer(), "int2");
    replacements["MAPS"] = cl.getBondedUtilities().addArgument(torsionMaps.getDeviceBuffer(), "int");
1220
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::cmapTorsionForce, replacements), force.getForceGroup());
1221
1222
    info = new ForceInfo(force);
    cl.addForce(info);
1223
1224
}

1225
/**
 * Perform no work and report an energy contribution of 0.0.
 *
 * NOTE(review): presumably the actual computation is carried out by the bonded
 * utilities that initialize() registered the interaction with -- confirm.
 */
double OpenCLCalcCMAPTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

1229
/**
 * Refresh the device-side CMAP data after the CMAPTorsionForce has been modified.
 *
 * The coefficients of every map are recomputed and uploaded, followed by the map
 * index assigned to each torsion handled by this context. Only values may change:
 * the number of maps, the size of each map, and the number of torsions assigned to
 * this context must match the sizes of the existing device arrays.
 *
 * @param context  the context being updated (not used by this method)
 * @param force    the CMAPTorsionForce to read the new parameters from
 * @throws OpenMMException if the number of maps, the size of a map, or the number
 *                         of torsions has changed
 */
void OpenCLCalcCMAPTorsionForceKernel::copyParametersToContext(ContextImpl& context, const CMAPTorsionForce& force) {
    // Each context handles a contiguous slice [startIndex, endIndex) of the torsions.
    int numMaps = force.getNumMaps();
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
    if (mapPositions.getSize() != numMaps)
        throw OpenMMException("updateParametersInContext: The number of maps has changed");
    if (torsionMaps.getSize() != numTorsions)
        throw OpenMMException("updateParametersInContext: The number of CMAP torsions has changed");

    // Update the maps.

    vector<mm_float4> coeffVec;
    vector<double> energy;
    vector<vector<double> > c;
    for (int i = 0; i < numMaps; i++) {
        int size;
        force.getMapParameters(i, size, energy);
        if (size != mapPositionsVec[i].y)
            throw OpenMMException("updateParametersInContext: The size of a map has changed");
        CMAPTorsionForceImpl::calcMapDerivatives(size, energy, c);
        for (int j = 0; j < size*size; j++) {
            // The 16 coefficients of each grid cell are packed into four consecutive float4 values.
            for (int k = 0; k < 16; k += 4)
                coeffVec.push_back(mm_float4((float) c[j][k], (float) c[j][k+1], (float) c[j][k+2], (float) c[j][k+3]));
        }
    }
    coefficients.upload(coeffVec);

    // Update the indices.  The torsions owned by this context start at startIndex,
    // exactly as in initialize(); indexing from 0 would upload the wrong torsions'
    // map indices on every context other than the first.

    vector<cl_int> torsionMapsVec(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        int index[8];
        force.getTorsionParameters(startIndex+i, torsionMapsVec[i], index[0], index[1], index[2], index[3], index[4], index[5], index[6], index[7]);
    }
    torsionMaps.upload(torsionMapsVec);
}

1272
/**
 * Exposes the torsions of a CustomTorsionForce as four-particle groups.
 */
class OpenCLCalcCustomTorsionForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CustomTorsionForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /** One group per torsion defined in the force. */
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
    /** Store the four particle indices of torsion `index` into `particles`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int members[4];
        vector<double> ignoredParameters;
        force.getTorsionParameters(index, members[0], members[1], members[2], members[3], ignoredParameters);
        particles.assign(members, members+4);
    }
    /** Two torsions are identical when all of their per-torsion parameters are equal. */
    bool areGroupsIdentical(int group1, int group2) {
        int a, b, c, d;
        vector<double> first, second;
        force.getTorsionParameters(group1, a, b, c, d, first);
        force.getTorsionParameters(group2, a, b, c, d, second);
        const int count = (int) first.size();
        int matched = 0;
        while (matched < count && first[matched] == second[matched])
            matched++;
        return (matched == count);
    }
private:
    const CustomTorsionForce& force;
};

/**
 * Release the per-torsion parameter set, if one was created.
 */
OpenCLCalcCustomTorsionForceKernel::~OpenCLCalcCustomTorsionForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete params;
}

void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const CustomTorsionForce& force) {
1309
1310
1311
1312
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1313
1314
    if (numTorsions == 0)
        return;
1315
    vector<vector<int> > atoms(numTorsions, vector<int>(4));
1316
1317
1318
1319
    params = new OpenCLParameterSet(cl, force.getNumPerTorsionParameters(), numTorsions, "customTorsionParams");
    vector<vector<cl_float> > paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        vector<double> parameters;
1320
        force.getTorsionParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], parameters);
1321
1322
1323
1324
1325
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
1326
1327
    info = new ForceInfo(force);
    cl.addForce(info);
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
1341
    expressions["real dEdAngle = "] = forceExpression;
1342
1343
1344
1345
1346
1347
1348
1349
1350

    // Create the kernels.

    map<string, string> variables;
    variables["theta"] = "theta";
    for (int i = 0; i < force.getNumPerTorsionParameters(); i++) {
        const string& name = force.getPerTorsionParameterName(i);
        variables[name] = "torsionParams"+params->getParameterSuffix(i);
    }
1351
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
1352
1353
1354
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customTorsionGlobals", CL_MEM_READ_ONLY);
        globals.upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals.getDeviceBuffer(), "float");
1355
1356
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
1357
            string value = argName+"["+cl.intToString(i)+"]";
1358
1359
            variables[name] = value;
        }
1360
    }
1361
1362
1363
1364
1365
1366
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string paramName = force.getEnergyParameterDerivativeName(i);
        string derivVariable = cl.getBondedUtilities().addEnergyParameterDerivative(paramName);
        Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
        expressions[derivVariable+" += "] = derivExpression;
    }
1367
1368
1369
    stringstream compute;
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
1370
1371
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
1372
    }
peastman's avatar
peastman committed
1373
1374
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
1375
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
1376
    map<string, string> replacements;
1377
    replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
1378
    replacements["COMPUTE_FORCE"] = compute.str();
1379
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
1380
1381
}

1382
/**
 * Per-step entry point for the custom torsion kernel.
 *
 * If a globals array was created, the current value of every global parameter is read
 * from the context, and the array is re-uploaded when at least one value differs from
 * the cached copy.
 *
 * @param context        the context to read global parameter values from
 * @param includeForces  whether forces were requested (unused)
 * @param includeEnergy  whether the energy was requested (unused)
 * @return always 0.0
 */
double OpenCLCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (!globals.isInitialized())
        return 0.0;
    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int g = 0; g < numGlobals; g++) {
        const cl_float latest = (cl_float) context.getParameter(globalParamNames[g]);
        needsUpload = needsUpload || (latest != globalParamValues[g]);
        globalParamValues[g] = latest;
    }
    if (needsUpload)
        globals.upload(globalParamValues);
    return 0.0;
}

1397
1398
1399
1400
1401
1402
/**
 * Re-upload the per-torsion parameters after the CustomTorsionForce has been modified.
 *
 * @param context  the context being updated (not used by this method)
 * @param force    the CustomTorsionForce to read the new parameters from
 * @throws OpenMMException if the number of torsions assigned to this context has changed
 */
void OpenCLCalcCustomTorsionForceKernel::copyParametersToContext(ContextImpl& context, const CustomTorsionForce& force) {
    // Recompute this context's slice of the torsions and make sure its size is unchanged.
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstTorsion = cl.getContextIndex()*force.getNumTorsions()/contextCount;
    const int lastTorsion = (cl.getContextIndex()+1)*force.getNumTorsions()/contextCount;
    if (numTorsions != lastTorsion-firstTorsion)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");
    if (numTorsions == 0)
        return;

    // Record the per-torsion parameters.

    vector<vector<cl_float> > paramVector(numTorsions);
    vector<double> values;
    for (int t = 0; t < numTorsions; t++) {
        int unusedAtoms[4];
        force.getTorsionParameters(firstTorsion+t, unusedAtoms[0], unusedAtoms[1], unusedAtoms[2], unusedAtoms[3], values);
        vector<cl_float>& row = paramVector[t];
        row.reserve(values.size());
        for (double value : values)
            row.push_back((cl_float) value);
    }
    params->setParameterValues(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

1424
/**
 * Exposes the particles and exceptions of a NonbondedForce: particles are compared by
 * their (charge, sigma, epsilon) values, and every exception is reported as a
 * two-particle group.
 */
class OpenCLCalcNonbondedForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(int requiredBuffers, const NonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    /** Two particles are identical when charge, sigma and epsilon all compare equal. */
    bool areParticlesIdentical(int particle1, int particle2) {
        double q1, sig1, eps1;
        double q2, sig2, eps2;
        force.getParticleParameters(particle1, q1, sig1, eps1);
        force.getParticleParameters(particle2, q2, sig2, eps2);
        if (q1 != q2)
            return false;
        if (sig1 != sig2)
            return false;
        return (eps1 == eps2);
    }
    /** One group per exception defined in the force. */
    int getNumParticleGroups() {
        return force.getNumExceptions();
    }
    /** Store the two particle indices of exception `index` into `particles`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int members[2];
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(index, members[0], members[1], chargeProd, sigma, epsilon);
        particles.assign(members, members+2);
    }
    /** Two exceptions are identical when chargeProd, sigma and epsilon all compare equal. */
    bool areGroupsIdentical(int group1, int group2) {
        int first, second;
        double qq1, sig1, eps1;
        double qq2, sig2, eps2;
        force.getExceptionParameters(group1, first, second, qq1, sig1, eps1);
        force.getExceptionParameters(group2, first, second, qq2, sig2, eps2);
        if (qq1 != qq2)
            return false;
        if (sig1 != sig2)
            return false;
        return (eps1 == eps2);
    }
private:
    const NonbondedForce& force;
};

1456
1457
/**
 * IO implementation handed to a CalcPmeReciprocalForceKernel: it downloads the posq
 * array from the OpenCL context and adds the forces passed to setForce() to the
 * context's force buffer via the supplied kernel.
 */
class OpenCLCalcNonbondedForceKernel::PmeIO : public CalcPmeReciprocalForceKernel::IO {
public:
    PmeIO(OpenCLContext& cl, cl::Kernel addForcesKernel) : cl(cl), addForcesKernel(addForcesKernel) {
        // Staging buffer with one float4 per atom; it is bound as argument 0 of the kernel.
        forceTemp.initialize<mm_float4>(cl, cl.getNumAtoms(), "PmeForce");
        addForcesKernel.setArg<cl::Buffer>(0, forceTemp.getDeviceBuffer());
    }
    /** Download posq into host memory and return a pointer to the first float. */
    float* getPosq() {
        cl.getPosq().download(posq);
        mm_float4* firstEntry = &posq[0];
        return reinterpret_cast<float*>(firstEntry);
    }
    /** Upload `force` into the staging buffer, then run the kernel over all atoms. */
    void setForce(float* force) {
        forceTemp.upload(force);
        // Argument 1 is set on every call to the context's current force buffer.
        addForcesKernel.setArg<cl::Buffer>(1, cl.getForce().getDeviceBuffer());
        cl.executeKernel(addForcesKernel, cl.getNumAtoms());
    }
private:
    OpenCLContext& cl;
    vector<mm_float4> posq;
    OpenCLArray forceTemp;
    cl::Kernel addForcesKernel;
};

/**
 * Pre-computation that starts the reciprocal space calculation on the wrapped PME kernel.
 */
class OpenCLCalcNonbondedForceKernel::PmePreComputation : public OpenCLContext::ForcePreComputation {
public:
    PmePreComputation(OpenCLContext& cl, Kernel& pme, CalcPmeReciprocalForceKernel::IO& io) : cl(cl), pme(pme), io(io) {
    }
    /**
     * Build box vectors from the x, y and z components of the context's periodic box
     * size (all off-diagonal elements are zero) and call beginComputation().
     */
    void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        Vec3 boxVectors[3];
        boxVectors[0] = Vec3(cl.getPeriodicBoxSize().x, 0, 0);
        boxVectors[1] = Vec3(0, cl.getPeriodicBoxSize().y, 0);
        boxVectors[2] = Vec3(0, 0, cl.getPeriodicBoxSize().z);
        CalcPmeReciprocalForceKernel& reciprocal = pme.getAs<CalcPmeReciprocalForceKernel>();
        reciprocal.beginComputation(io, boxVectors, includeEnergy);
    }
private:
    OpenCLContext& cl;
    Kernel pme;
    CalcPmeReciprocalForceKernel::IO& io;
};

/**
 * Post-computation that finishes the reciprocal space calculation on the wrapped PME
 * kernel and returns the value it reports.
 */
class OpenCLCalcNonbondedForceKernel::PmePostComputation : public OpenCLContext::ForcePostComputation {
public:
    PmePostComputation(Kernel& pme, CalcPmeReciprocalForceKernel::IO& io) : pme(pme), io(io) {
    }
    /** @return the result of finishComputation() on the PME kernel */
    double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        CalcPmeReciprocalForceKernel& reciprocal = pme.getAs<CalcPmeReciprocalForceKernel>();
        return reciprocal.finishComputation(io);
    }
private:
    Kernel pme;
    CalcPmeReciprocalForceKernel::IO& io;
};

1504
1505
/**
 * Pre-computation that makes a secondary command queue wait for a marker enqueued on
 * the context's main queue.
 */
class OpenCLCalcNonbondedForceKernel::SyncQueuePreComputation : public OpenCLContext::ForcePreComputation {
public:
    SyncQueuePreComputation(OpenCLContext& cl, cl::CommandQueue queue, int forceGroup) : cl(cl), queue(queue), forceGroup(forceGroup) {
    }
    /** Does nothing unless `forceGroup` is among the groups selected by the `groups` bit mask. */
    void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        const bool groupSelected = ((groups&(1<<forceGroup)) != 0);
        if (!groupSelected)
            return;
        // Enqueue a marker on the main queue, then have the secondary queue wait for it.
        vector<cl::Event> events(1);
        cl.getQueue().enqueueMarker(&events[0]);
        queue.enqueueWaitForEvents(events);
    }
private:
    OpenCLContext& cl;
    cl::CommandQueue queue;
    int forceGroup;
};

/**
 * Post-computation that makes the context's main queue wait for a stored event and,
 * when the energy is requested, runs a kernel that takes the PME energy buffer and the
 * context's energy buffer as arguments.
 */
class OpenCLCalcNonbondedForceKernel::SyncQueuePostComputation : public OpenCLContext::ForcePostComputation {
public:
    SyncQueuePostComputation(OpenCLContext& cl, cl::Event& event, OpenCLArray& pmeEnergyBuffer, int forceGroup) : cl(cl), event(event),
            pmeEnergyBuffer(pmeEnergyBuffer), forceGroup(forceGroup) {
    }
    /** Store the kernel and bind its arguments: PME energy buffer, context energy buffer, buffer size. */
    void setKernel(cl::Kernel kernel) {
        addEnergyKernel = kernel;
        addEnergyKernel.setArg<cl::Buffer>(0, pmeEnergyBuffer.getDeviceBuffer());
        addEnergyKernel.setArg<cl::Buffer>(1, cl.getEnergyBuffer().getDeviceBuffer());
        addEnergyKernel.setArg<cl_int>(2, pmeEnergyBuffer.getSize());
    }
    /**
     * Does nothing unless `forceGroup` is among the groups selected by the `groups`
     * bit mask.
     *
     * @return always 0.0
     */
    double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        const bool groupSelected = ((groups&(1<<forceGroup)) != 0);
        if (!groupSelected)
            return 0.0;
        // Take a copy of the shared event and reset the original before waiting on the copy.
        vector<cl::Event> events(1, event);
        event = cl::Event();
        cl.getQueue().enqueueWaitForEvents(events);
        if (includeEnergy)
            cl.executeKernel(addEnergyKernel, pmeEnergyBuffer.getSize());
        return 0.0;
    }
private:
    OpenCLContext& cl;
    cl::Event& event;
    cl::Kernel addEnergyKernel;
    OpenCLArray& pmeEnergyBuffer;
    int forceGroup;
};

1551
/**
 * Release the helper objects owned through raw pointers: the sorter, the two FFT
 * objects and the PME IO adapter.
 */
OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() {
    // Deleting a null pointer is a no-op, so members that were never created need no check.
    delete sort;
    delete fft;
    delete dispersionFft;
    delete pmeio;
}

void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) {
1563
1564
1565
1566
    int forceIndex;
    for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
        ;
    string prefix = "nonbonded"+cl.intToString(forceIndex)+"_";
1567
1568
1569

    // Identify which exceptions are 1-4 interactions.

1570
1571
1572
1573
1574
1575
1576
1577
    set<int> exceptionsWithOffsets;
    for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) {
        string param;
        int exception;
        double charge, sigma, epsilon;
        force.getExceptionParameterOffset(i, param, exception, charge, sigma, epsilon);
        exceptionsWithOffsets.insert(exception);
    }
1578
1579
    vector<pair<int, int> > exclusions;
    vector<int> exceptions;
1580
    map<int, int> exceptionIndex;
1581
1582
1583
1584
1585
    for (int i = 0; i < force.getNumExceptions(); i++) {
        int particle1, particle2;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(i, particle1, particle2, chargeProd, sigma, epsilon);
        exclusions.push_back(pair<int, int>(particle1, particle2));
1586
1587
        if (chargeProd != 0.0 || epsilon != 0.0 || exceptionsWithOffsets.find(i) != exceptionsWithOffsets.end()) {
            exceptionIndex[i] = exceptions.size();
1588
            exceptions.push_back(i);
1589
        }
1590
1591
1592
1593
1594
    }

    // Initialize nonbonded interactions.

    int numParticles = force.getNumParticles();
1595
    vector<mm_float4> baseParticleParamVec(cl.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
1596
    vector<vector<int> > exclusionList(numParticles);
1597
1598
    hasCoulomb = false;
    hasLJ = false;
1599
1600
1601
    for (int i = 0; i < numParticles; i++) {
        double charge, sigma, epsilon;
        force.getParticleParameters(i, charge, sigma, epsilon);
1602
        baseParticleParamVec[i] = mm_float4(charge, sigma, epsilon, 0);
1603
        exclusionList[i].push_back(i);
1604
1605
1606
1607
        if (charge != 0.0)
            hasCoulomb = true;
        if (epsilon != 0.0)
            hasLJ = true;
1608
    }
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
    for (int i = 0; i < force.getNumParticleParameterOffsets(); i++) {
        string param;
        int particle;
        double charge, sigma, epsilon;
        force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon);
        if (charge != 0.0)
            hasCoulomb = true;
        if (epsilon != 0.0)
            hasLJ = true;
    }
peastman's avatar
peastman committed
1619
1620
1621
    for (auto exclusion : exclusions) {
        exclusionList[exclusion.first].push_back(exclusion.second);
        exclusionList[exclusion.second].push_back(exclusion.first);
1622
    }
1623
1624
1625
    nonbondedMethod = CalcNonbondedForceKernel::NonbondedMethod(force.getNonbondedMethod());
    bool useCutoff = (nonbondedMethod != NoCutoff);
    bool usePeriodic = (nonbondedMethod != NoCutoff && nonbondedMethod != CutoffNonPeriodic);
1626
    doLJPME = (nonbondedMethod == LJPME && hasLJ);
1627
    usePosqCharges = hasCoulomb ? cl.requestPosqCharges() : false;
1628
    map<string, string> defines;
1629
1630
    defines["HAS_COULOMB"] = (hasCoulomb ? "1" : "0");
    defines["HAS_LENNARD_JONES"] = (hasLJ ? "1" : "0");
1631
    defines["USE_LJ_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
1632
    if (useCutoff) {
1633
1634
        // Compute the reaction field constants.

1635
1636
        double reactionFieldK = pow(force.getCutoffDistance(), -3.0)*(force.getReactionFieldDielectric()-1.0)/(2.0*force.getReactionFieldDielectric()+1.0);
        double reactionFieldC = (1.0 / force.getCutoffDistance())*(3.0*force.getReactionFieldDielectric())/(2.0*force.getReactionFieldDielectric()+1.0);
1637
1638
        defines["REACTION_FIELD_K"] = cl.doubleToString(reactionFieldK);
        defines["REACTION_FIELD_C"] = cl.doubleToString(reactionFieldC);
1639
1640
1641
1642
1643
1644
1645
1646
1647
        
        // Compute the switching coefficients.
        
        if (force.getUseSwitchingFunction()) {
            defines["LJ_SWITCH_CUTOFF"] = cl.doubleToString(force.getSwitchingDistance());
            defines["LJ_SWITCH_C3"] = cl.doubleToString(10/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 3.0));
            defines["LJ_SWITCH_C4"] = cl.doubleToString(15/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 4.0));
            defines["LJ_SWITCH_C5"] = cl.doubleToString(6/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 5.0));
        }
1648
    }
1649
    if (force.getUseDispersionCorrection() && cl.getContextIndex() == 0 && !doLJPME)
1650
1651
1652
        dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(system, force);
    else
        dispersionCoefficient = 0.0;
1653
    alpha = 0;
1654
    ewaldSelfEnergy = 0.0;
1655
    map<string, string> paramsDefines;
Peter Eastman's avatar
Bug fix  
Peter Eastman committed
1656
1657
    hasOffsets = (force.getNumParticleParameterOffsets() > 0 || force.getNumExceptionParameterOffsets() > 0);
    if (hasOffsets)
1658
        paramsDefines["HAS_OFFSETS"] = "1";
1659
1660
    if (usePosqCharges)
        paramsDefines["USE_POSQ_CHARGES"] = "1";
1661
    if (nonbondedMethod == Ewald) {
1662
1663
1664
1665
        // Compute the Ewald parameters.

        int kmaxx, kmaxy, kmaxz;
        NonbondedForceImpl::calcEwaldParameters(system, force, alpha, kmaxx, kmaxy, kmaxz);
1666
1667
        defines["EWALD_ALPHA"] = cl.doubleToString(alpha);
        defines["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI));
1668
        defines["USE_EWALD"] = "1";
1669
        if (cl.getContextIndex() == 0) {
1670
1671
            paramsDefines["INCLUDE_EWALD"] = "1";
            paramsDefines["EWALD_SELF_ENERGY_SCALE"] = cl.doubleToString(ONE_4PI_EPS0*alpha/sqrt(M_PI));
Peter Eastman's avatar
Peter Eastman committed
1672
            for (int i = 0; i < numParticles; i++)
Peter Eastman's avatar
Bug fix  
Peter Eastman committed
1673
                ewaldSelfEnergy -= baseParticleParamVec[i].x*baseParticleParamVec[i].x*ONE_4PI_EPS0*alpha/sqrt(M_PI);
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686

            // Create the reciprocal space kernels.

            map<string, string> replacements;
            replacements["NUM_ATOMS"] = cl.intToString(numParticles);
            replacements["KMAX_X"] = cl.intToString(kmaxx);
            replacements["KMAX_Y"] = cl.intToString(kmaxy);
            replacements["KMAX_Z"] = cl.intToString(kmaxz);
            replacements["EXP_COEFFICIENT"] = cl.doubleToString(-1.0/(4.0*alpha*alpha));
            cl::Program program = cl.createProgram(OpenCLKernelSources::ewald, replacements);
            ewaldSumsKernel = cl::Kernel(program, "calculateEwaldCosSinSums");
            ewaldForcesKernel = cl::Kernel(program, "calculateEwaldForces");
            int elementSize = (cl.getUseDoublePrecision() ? sizeof(mm_double2) : sizeof(mm_float2));
peastman's avatar
peastman committed
1687
            cosSinSums.initialize(cl, (2*kmaxx-1)*(2*kmaxy-1)*(2*kmaxz-1), elementSize, "cosSinSums");
1688
1689
        }
    }
peastman's avatar
peastman committed
1690
    else if (((nonbondedMethod == PME || nonbondedMethod == LJPME) && hasCoulomb) || doLJPME) {
1691
1692
        // Compute the PME parameters.

1693
        NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSizeX, gridSizeY, gridSizeZ, false);
1694
1695
1696
        gridSizeX = OpenCLFFT3D::findLegalDimension(gridSizeX);
        gridSizeY = OpenCLFFT3D::findLegalDimension(gridSizeY);
        gridSizeZ = OpenCLFFT3D::findLegalDimension(gridSizeZ);
1697
1698
1699
1700
1701
1702
1703
        if (doLJPME) {
            NonbondedForceImpl::calcPMEParameters(system, force, dispersionAlpha, dispersionGridSizeX,
                                                  dispersionGridSizeY, dispersionGridSizeZ, true);
            dispersionGridSizeX = OpenCLFFT3D::findLegalDimension(dispersionGridSizeX);
            dispersionGridSizeY = OpenCLFFT3D::findLegalDimension(dispersionGridSizeY);
            dispersionGridSizeZ = OpenCLFFT3D::findLegalDimension(dispersionGridSizeZ);
        }
1704
1705
        defines["EWALD_ALPHA"] = cl.doubleToString(alpha);
        defines["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI));
1706
        defines["USE_EWALD"] = "1";
1707
1708
1709
        defines["DO_LJPME"] = doLJPME ? "1" : "0";
        if (doLJPME)
            defines["EWALD_DISPERSION_ALPHA"] = cl.doubleToString(dispersionAlpha);
1710
        if (cl.getContextIndex() == 0) {
1711
1712
            paramsDefines["INCLUDE_EWALD"] = "1";
            paramsDefines["EWALD_SELF_ENERGY_SCALE"] = cl.doubleToString(ONE_4PI_EPS0*alpha/sqrt(M_PI));
Peter Eastman's avatar
Peter Eastman committed
1713
            for (int i = 0; i < numParticles; i++)
1714
                ewaldSelfEnergy -= baseParticleParamVec[i].x*baseParticleParamVec[i].x*ONE_4PI_EPS0*alpha/sqrt(M_PI);
1715
1716
1717
            if (doLJPME) {
                paramsDefines["INCLUDE_LJPME"] = "1";
                paramsDefines["LJPME_SELF_ENERGY_SCALE"] = cl.doubleToString(pow(dispersionAlpha, 6)/3.0);
Peter Eastman's avatar
Peter Eastman committed
1718
1719
                for (int i = 0; i < numParticles; i++)
                    ewaldSelfEnergy += baseParticleParamVec[i].z*pow(baseParticleParamVec[i].y*dispersionAlpha, 6)/3.0;
1720
            }
1721
1722
1723
1724
1725
1726
1727
            pmeDefines["PME_ORDER"] = cl.intToString(PmeOrder);
            pmeDefines["NUM_ATOMS"] = cl.intToString(numParticles);
            pmeDefines["RECIP_EXP_FACTOR"] = cl.doubleToString(M_PI*M_PI/(alpha*alpha));
            pmeDefines["GRID_SIZE_X"] = cl.intToString(gridSizeX);
            pmeDefines["GRID_SIZE_Y"] = cl.intToString(gridSizeY);
            pmeDefines["GRID_SIZE_Z"] = cl.intToString(gridSizeZ);
            pmeDefines["EPSILON_FACTOR"] = cl.doubleToString(sqrt(ONE_4PI_EPS0));
1728
            pmeDefines["M_PI"] = cl.doubleToString(M_PI);
1729
1730
1731
            bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
            if (deviceIsCpu)
                pmeDefines["DEVICE_IS_CPU"] = "1";
1732
            if (cl.getPlatformData().useCpuPme && !doLJPME && usePosqCharges) {
1733
1734
1735
1736
                // Create the CPU PME kernel.

                try {
                    cpuPme = getPlatform().createKernel(CalcPmeReciprocalForceKernel::Name(), *cl.getPlatformData().context);
1737
                    cpuPme.getAs<CalcPmeReciprocalForceKernel>().initialize(gridSizeX, gridSizeY, gridSizeZ, numParticles, alpha, false);
1738
1739
1740
1741
1742
                    cl::Program program = cl.createProgram(OpenCLKernelSources::pme, pmeDefines);
                    cl::Kernel addForcesKernel = cl::Kernel(program, "addForces");
                    pmeio = new PmeIO(cl, addForcesKernel);
                    cl.addPreComputation(new PmePreComputation(cl, cpuPme, *pmeio));
                    cl.addPostComputation(new PmePostComputation(cpuPme, *pmeio));
1743
                }
1744
1745
                catch (OpenMMException& ex) {
                    // The CPU PME plugin isn't available.
1746
                }
1747
1748
1749
1750
            }
            if (pmeio == NULL) {
                // Create required data structures.

1751
1752
1753
1754
1755
1756
1757
1758
1759
                if (doLJPME) {
                    double invRCut6 = pow(force.getCutoffDistance(), -6);
                    double dalphaR = dispersionAlpha * force.getCutoffDistance();
                    double dar2 = dalphaR*dalphaR;
                    double dar4 = dar2*dar2;
                    double multShift6 = -invRCut6*(1.0 - exp(-dar2) * (1.0 + dar2 + 0.5*dar4));
                    defines["INVCUT6"] = cl.doubleToString(invRCut6);
                    defines["MULTSHIFT6"] = cl.doubleToString(multShift6);
                }
1760
                int elementSize = (cl.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
Peter Eastman's avatar
Peter Eastman committed
1761
1762
1763
1764
1765
1766
1767
                int roundedZSize = PmeOrder*(int) ceil(gridSizeZ/(double) PmeOrder);
                int gridElements = gridSizeX*gridSizeY*roundedZSize;
                if (doLJPME) {
                    roundedZSize = PmeOrder*(int) ceil(dispersionGridSizeZ/(double) PmeOrder);
                    gridElements = max(gridElements, dispersionGridSizeX*dispersionGridSizeY*roundedZSize);
                }
                pmeGrid1.initialize(cl, gridElements, 2*elementSize, "pmeGrid1");
peastman's avatar
peastman committed
1768
                pmeGrid2.initialize(cl, gridElements, 2*elementSize, "pmeGrid2");
peastman's avatar
peastman committed
1769
                if (cl.getSupports64BitGlobalAtomics())
peastman's avatar
peastman committed
1770
                    cl.addAutoclearBuffer(pmeGrid2);
peastman's avatar
peastman committed
1771
                else
Peter Eastman's avatar
Peter Eastman committed
1772
                    cl.addAutoclearBuffer(pmeGrid1);
peastman's avatar
peastman committed
1773
1774
1775
                pmeBsplineModuliX.initialize(cl, gridSizeX, elementSize, "pmeBsplineModuliX");
                pmeBsplineModuliY.initialize(cl, gridSizeY, elementSize, "pmeBsplineModuliY");
                pmeBsplineModuliZ.initialize(cl, gridSizeZ, elementSize, "pmeBsplineModuliZ");
1776
                if (doLJPME) {
peastman's avatar
peastman committed
1777
1778
1779
                    pmeDispersionBsplineModuliX.initialize(cl, dispersionGridSizeX, elementSize, "pmeDispersionBsplineModuliX");
                    pmeDispersionBsplineModuliY.initialize(cl, dispersionGridSizeY, elementSize, "pmeDispersionBsplineModuliY");
                    pmeDispersionBsplineModuliZ.initialize(cl, dispersionGridSizeZ, elementSize, "pmeDispersionBsplineModuliZ");
1780
                }
peastman's avatar
peastman committed
1781
1782
1783
                pmeBsplineTheta.initialize(cl, PmeOrder*numParticles, 4*elementSize, "pmeBsplineTheta");
                pmeAtomRange.initialize<cl_int>(cl, gridSizeX*gridSizeY*gridSizeZ+1, "pmeAtomRange");
                pmeAtomGridIndex.initialize<mm_int2>(cl, numParticles, "pmeAtomGridIndex");
1784
                int energyElementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
peastman's avatar
peastman committed
1785
                pmeEnergyBuffer.initialize(cl, cl.getNumThreadBlocks()*OpenCLContext::ThreadBlockSize, energyElementSize, "pmeEnergyBuffer");
1786
                cl.clearBuffer(pmeEnergyBuffer);
1787
                sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms());
1788
                fft = new OpenCLFFT3D(cl, gridSizeX, gridSizeY, gridSizeZ, true);
1789
1790
                if (doLJPME)
                    dispersionFft = new OpenCLFFT3D(cl, dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ, true);
1791
                string vendor = cl.getDevice().getInfo<CL_DEVICE_VENDOR>();
Peter Eastman's avatar
Peter Eastman committed
1792
                bool isNvidia = (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA");
1793
                usePmeQueue = (!cl.getPlatformData().disablePmeStream && isNvidia);
1794
                if (usePmeQueue) {
peastman's avatar
peastman committed
1795
                    pmeDefines["USE_PME_STREAM"] = "1";
1796
1797
1798
1799
1800
                    pmeQueue = cl::CommandQueue(cl.getContext(), cl.getDevice());
                    int recipForceGroup = force.getReciprocalSpaceForceGroup();
                    if (recipForceGroup < 0)
                        recipForceGroup = force.getForceGroup();
                    cl.addPreComputation(new SyncQueuePreComputation(cl, pmeQueue, recipForceGroup));
peastman's avatar
peastman committed
1801
                    cl.addPostComputation(syncQueue = new SyncQueuePostComputation(cl, pmeSyncEvent, pmeEnergyBuffer, recipForceGroup));
1802
                }
1803
1804
1805

                // Initialize the b-spline moduli.

1806
1807
1808
1809
1810
1811
1812
                for (int grid = 0; grid < 2; grid++) {
                    int xsize, ysize, zsize;
                    OpenCLArray *xmoduli, *ymoduli, *zmoduli;
                    if (grid == 0) {
                        xsize = gridSizeX;
                        ysize = gridSizeY;
                        zsize = gridSizeZ;
peastman's avatar
peastman committed
1813
1814
1815
                        xmoduli = &pmeBsplineModuliX;
                        ymoduli = &pmeBsplineModuliY;
                        zmoduli = &pmeBsplineModuliZ;
1816
                    }
1817
1818
1819
1820
1821
1822
                    else {
                        if (!doLJPME)
                            continue;
                        xsize = dispersionGridSizeX;
                        ysize = dispersionGridSizeY;
                        zsize = dispersionGridSizeZ;
peastman's avatar
peastman committed
1823
1824
1825
                        xmoduli = &pmeDispersionBsplineModuliX;
                        ymoduli = &pmeDispersionBsplineModuliY;
                        zmoduli = &pmeDispersionBsplineModuliZ;
1826
                    }
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
                    int maxSize = max(max(xsize, ysize), zsize);
                    vector<double> data(PmeOrder);
                    vector<double> ddata(PmeOrder);
                    vector<double> bsplines_data(maxSize);
                    data[PmeOrder-1] = 0.0;
                    data[1] = 0.0;
                    data[0] = 1.0;
                    for (int i = 3; i < PmeOrder; i++) {
                        double div = 1.0/(i-1.0);
                        data[i-1] = 0.0;
                        for (int j = 1; j < (i-1); j++)
                            data[i-j-1] = div*(j*data[i-j-2]+(i-j)*data[i-j-1]);
                        data[0] = div*data[0];
1840
                    }
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858

                    // Differentiate.

                    ddata[0] = -data[0];
                    for (int i = 1; i < PmeOrder; i++)
                        ddata[i] = data[i-1]-data[i];
                    double div = 1.0/(PmeOrder-1);
                    data[PmeOrder-1] = 0.0;
                    for (int i = 1; i < (PmeOrder-1); i++)
                        data[PmeOrder-i-1] = div*(i*data[PmeOrder-i-2]+(PmeOrder-i)*data[PmeOrder-i-1]);
                    data[0] = div*data[0];
                    for (int i = 0; i < maxSize; i++)
                        bsplines_data[i] = 0.0;
                    for (int i = 1; i <= PmeOrder; i++)
                        bsplines_data[i] = data[i-1];

                    // Evaluate the actual bspline moduli for X/Y/Z.

1859
                    for (int dim = 0; dim < 3; dim++) {
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
                        int ndata = (dim == 0 ? xsize : dim == 1 ? ysize : zsize);
                        vector<cl_double> moduli(ndata);
                        for (int i = 0; i < ndata; i++) {
                            double sc = 0.0;
                            double ss = 0.0;
                            for (int j = 0; j < ndata; j++) {
                                double arg = (2.0*M_PI*i*j)/ndata;
                                sc += bsplines_data[j]*cos(arg);
                                ss += bsplines_data[j]*sin(arg);
                            }
peastman's avatar
peastman committed
1870
                            moduli[i] = sc*sc+ss*ss;
1871
                        }
1872
                        for (int i = 0; i < ndata; i++)
1873
1874
1875
1876
                        {
                            if (moduli[i] < 1.0e-7)
                                moduli[i] = (moduli[i-1]+moduli[i+1])*0.5f;
                        }
peastman's avatar
peastman committed
1877
1878
1879
1880
1881
1882
                        if (dim == 0)
                            xmoduli->upload(moduli, true, true);
                        else if (dim == 1)
                            ymoduli->upload(moduli, true, true);
                        else
                            zmoduli->upload(moduli, true, true);
1883
                    }
1884
                }
1885
            }
1886
1887
        }
    }
1888

1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
    // Add code to subtract off the reciprocal part of excluded interactions.

    if ((nonbondedMethod == Ewald || nonbondedMethod == PME || nonbondedMethod == LJPME) && pmeio == NULL) {
        int numContexts = cl.getPlatformData().contexts.size();
        int startIndex = cl.getContextIndex()*force.getNumExceptions()/numContexts;
        int endIndex = (cl.getContextIndex()+1)*force.getNumExceptions()/numContexts;
        int numExclusions = endIndex-startIndex;
        if (numExclusions > 0) {
            paramsDefines["HAS_EXCLUSIONS"] = "1";
            vector<vector<int> > atoms(numExclusions, vector<int>(2));
            exclusionAtoms.initialize<mm_int2>(cl, numExclusions, "exclusionAtoms");
            exclusionParams.initialize<mm_float4>(cl, numExclusions, "exclusionParams");
            vector<mm_int2> exclusionAtomsVec(numExclusions);
            for (int i = 0; i < numExclusions; i++) {
                int j = i+startIndex;
                exclusionAtomsVec[i] = mm_int2(exclusions[j].first, exclusions[j].second);
                atoms[i][0] = exclusions[j].first;
                atoms[i][1] = exclusions[j].second;
            }
            exclusionAtoms.upload(exclusionAtomsVec);
            map<string, string> replacements;
            replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exclusionParams.getDeviceBuffer(), "float4");
            replacements["EWALD_ALPHA"] = cl.doubleToString(alpha);
            replacements["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI));
            replacements["DO_LJPME"] = doLJPME ? "1" : "0";
            if (doLJPME)
                replacements["EWALD_DISPERSION_ALPHA"] = cl.doubleToString(dispersionAlpha);
            cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::pmeExclusions, replacements), force.getForceGroup());
        }
    }

1920
1921
    // Add the interaction to the default nonbonded kernel.
    
1922
    string source = cl.replaceStrings(OpenCLKernelSources::coulombLennardJones, defines);
1923
    charges.initialize(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(double) : sizeof(float), "charges");
1924
1925
    baseParticleParams.initialize<mm_float4>(cl, cl.getPaddedNumAtoms(), "baseParticleParams");
    baseParticleParams.upload(baseParticleParamVec);
peastman's avatar
peastman committed
1926
1927
1928
1929
1930
1931
1932
1933
    map<string, string> replacements;
    if (usePosqCharges) {
        replacements["CHARGE1"] = "posq1.w";
        replacements["CHARGE2"] = "posq2.w";
    }
    else {
        replacements["CHARGE1"] = prefix+"charge1";
        replacements["CHARGE2"] = prefix+"charge2";
1934
    }
peastman's avatar
peastman committed
1935
1936
    if (hasCoulomb)
        cl.getNonbondedUtilities().addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"charge", "real", 1, charges.getElementSize(), charges.getDeviceBuffer()));
1937
    sigmaEpsilon.initialize<mm_float2>(cl, cl.getPaddedNumAtoms(), "sigmaEpsilon");
1938
1939
1940
1941
1942
    if (hasLJ) {
        replacements["SIGMA_EPSILON1"] = prefix+"sigmaEpsilon1";
        replacements["SIGMA_EPSILON2"] = prefix+"sigmaEpsilon2";
        cl.getNonbondedUtilities().addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"sigmaEpsilon", "float", 2, sizeof(cl_float2), sigmaEpsilon.getDeviceBuffer()));
    }
peastman's avatar
peastman committed
1943
    source = cl.replaceStrings(source, replacements);
1944
    cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
1945

1946
    // Initialize the exceptions.
1947

1948
1949
1950
1951
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*exceptions.size()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*exceptions.size()/numContexts;
    int numExceptions = endIndex-startIndex;
1952
    if (numExceptions > 0) {
1953
        paramsDefines["HAS_EXCEPTIONS"] = "1";
1954
        exceptionAtoms.resize(numExceptions);
Peter Eastman's avatar
Peter Eastman committed
1955
        vector<vector<int> > atoms(numExceptions, vector<int>(2));
peastman's avatar
peastman committed
1956
        exceptionParams.initialize<mm_float4>(cl, numExceptions, "exceptionParams");
1957
1958
        baseExceptionParams.initialize<mm_float4>(cl, numExceptions, "baseExceptionParams");
        vector<mm_float4> baseExceptionParamsVec(numExceptions);
1959
        for (int i = 0; i < numExceptions; i++) {
1960
            double chargeProd, sigma, epsilon;
Peter Eastman's avatar
Peter Eastman committed
1961
            force.getExceptionParameters(exceptions[startIndex+i], atoms[i][0], atoms[i][1], chargeProd, sigma, epsilon);
1962
            baseExceptionParamsVec[i] = mm_float4(chargeProd, sigma, epsilon, 0);
1963
            exceptionAtoms[i] = make_pair(atoms[i][0], atoms[i][1]);
1964
        }
1965
        baseExceptionParams.upload(baseExceptionParamsVec);
Peter Eastman's avatar
Peter Eastman committed
1966
        map<string, string> replacements;
peastman's avatar
peastman committed
1967
        replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exceptionParams.getDeviceBuffer(), "float4");
1968
        cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::nonbondedExceptions, replacements), force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
1969
    }
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
    
    // Initialize parameter offsets.

    vector<vector<mm_float4> > particleOffsetVec(force.getNumParticles());
    vector<vector<mm_float4> > exceptionOffsetVec(force.getNumExceptions());
    for (int i = 0; i < force.getNumParticleParameterOffsets(); i++) {
        string param;
        int particle;
        double charge, sigma, epsilon;
        force.getParticleParameterOffset(i, param, particle, charge, sigma, epsilon);
        auto paramPos = find(paramNames.begin(), paramNames.end(), param);
        int paramIndex;
        if (paramPos == paramNames.end()) {
            paramIndex = paramNames.size();
            paramNames.push_back(param);
        }
        else
            paramIndex = paramPos-paramNames.begin();
        particleOffsetVec[particle].push_back(mm_float4(charge, sigma, epsilon, paramIndex));
    }
    for (int i = 0; i < force.getNumExceptionParameterOffsets(); i++) {
        string param;
        int exception;
        double charge, sigma, epsilon;
        force.getExceptionParameterOffset(i, param, exception, charge, sigma, epsilon);
        auto paramPos = find(paramNames.begin(), paramNames.end(), param);
        int paramIndex;
        if (paramPos == paramNames.end()) {
            paramIndex = paramNames.size();
            paramNames.push_back(param);
        }
        else
            paramIndex = paramPos-paramNames.begin();
2003
        exceptionOffsetVec[exceptionIndex[exception]].push_back(mm_float4(charge, sigma, epsilon, paramIndex));
2004
2005
2006
2007
    }
    paramValues.resize(paramNames.size(), 0.0);
    particleParamOffsets.initialize<mm_float4>(cl, max(force.getNumParticleParameterOffsets(), 1), "particleParamOffsets");
    exceptionParamOffsets.initialize<mm_float4>(cl, max(force.getNumExceptionParameterOffsets(), 1), "exceptionParamOffsets");
Peter Eastman's avatar
Bug fix  
Peter Eastman committed
2008
    particleOffsetIndices.initialize<cl_int>(cl, cl.getPaddedNumAtoms()+1, "particleOffsetIndices");
2009
2010
2011
2012
2013
2014
2015
2016
    exceptionOffsetIndices.initialize<cl_int>(cl, force.getNumExceptions()+1, "exceptionOffsetIndices");
    vector<cl_int> particleOffsetIndicesVec, exceptionOffsetIndicesVec;
    vector<mm_float4> p, e;
    for (int i = 0; i < particleOffsetVec.size(); i++) {
        particleOffsetIndicesVec.push_back(p.size());
        for (int j = 0; j < particleOffsetVec[i].size(); j++)
            p.push_back(particleOffsetVec[i][j]);
    }
Peter Eastman's avatar
Bug fix  
Peter Eastman committed
2017
2018
    while (particleOffsetIndicesVec.size() < particleOffsetIndices.getSize())
        particleOffsetIndicesVec.push_back(p.size());
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
    for (int i = 0; i < exceptionOffsetVec.size(); i++) {
        exceptionOffsetIndicesVec.push_back(e.size());
        for (int j = 0; j < exceptionOffsetVec[i].size(); j++)
            e.push_back(exceptionOffsetVec[i][j]);
    }
    exceptionOffsetIndicesVec.push_back(e.size());
    if (force.getNumParticleParameterOffsets() > 0) {
        particleParamOffsets.upload(p);
        particleOffsetIndices.upload(particleOffsetIndicesVec);
    }
    if (force.getNumExceptionParameterOffsets() > 0) {
        exceptionParamOffsets.upload(e);
        exceptionOffsetIndices.upload(exceptionOffsetIndicesVec);
    }
    globalParams.initialize(cl, max((int) paramValues.size(), 1), cl.getUseDoublePrecision() ? sizeof(double) : sizeof(float), "globalParams");
Peter Eastman's avatar
Peter Eastman committed
2034
    recomputeParams = true;
2035
2036
2037
2038
2039
    
    // Initialize the kernel for updating parameters.
    
    cl::Program program = cl.createProgram(OpenCLKernelSources::nonbondedParameters, paramsDefines);
    computeParamsKernel = cl::Kernel(program, "computeParameters");
Peter Eastman's avatar
Peter Eastman committed
2040
    computeExclusionParamsKernel = cl::Kernel(program, "computeExclusionParameters");
2041
2042
    info = new ForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force);
    cl.addForce(info);
2043
2044
}

2045
double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) {
2046
    bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
2047
2048
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
        int index = 0;
        computeParamsKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        index++;
        computeParamsKernel.setArg<cl::Buffer>(index++, globalParams.getDeviceBuffer());
        computeParamsKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
        computeParamsKernel.setArg<cl::Buffer>(index++, baseParticleParams.getDeviceBuffer());
        computeParamsKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        computeParamsKernel.setArg<cl::Buffer>(index++, charges.getDeviceBuffer());
        computeParamsKernel.setArg<cl::Buffer>(index++, sigmaEpsilon.getDeviceBuffer());
        computeParamsKernel.setArg<cl::Buffer>(index++, particleParamOffsets.getDeviceBuffer());
        computeParamsKernel.setArg<cl::Buffer>(index++, particleOffsetIndices.getDeviceBuffer());
2060
        if (exceptionParams.isInitialized()) {
2061
2062
2063
2064
2065
2066
2067
            computeParamsKernel.setArg<cl_int>(index++, exceptionParams.getSize());
            computeParamsKernel.setArg<cl::Buffer>(index++, baseExceptionParams.getDeviceBuffer());
            computeParamsKernel.setArg<cl::Buffer>(index++, exceptionParams.getDeviceBuffer());
            computeParamsKernel.setArg<cl::Buffer>(index++, exceptionParamOffsets.getDeviceBuffer());
            computeParamsKernel.setArg<cl::Buffer>(index++, exceptionOffsetIndices.getDeviceBuffer());
        }
        if (exclusionParams.isInitialized()) {
Peter Eastman's avatar
Peter Eastman committed
2068
2069
2070
2071
2072
2073
            computeExclusionParamsKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
            computeExclusionParamsKernel.setArg<cl::Buffer>(1, charges.getDeviceBuffer());
            computeExclusionParamsKernel.setArg<cl::Buffer>(2, sigmaEpsilon.getDeviceBuffer());
            computeExclusionParamsKernel.setArg<cl_int>(3, exclusionParams.getSize());
            computeExclusionParamsKernel.setArg<cl::Buffer>(4, exclusionAtoms.getDeviceBuffer());
            computeExclusionParamsKernel.setArg<cl::Buffer>(5, exclusionParams.getDeviceBuffer());
2074
        }
peastman's avatar
peastman committed
2075
        if (cosSinSums.isInitialized()) {
2076
2077
            ewaldSumsKernel.setArg<cl::Buffer>(0, cl.getEnergyBuffer().getDeviceBuffer());
            ewaldSumsKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2078
            ewaldSumsKernel.setArg<cl::Buffer>(2, cosSinSums.getDeviceBuffer());
2079
2080
            ewaldForcesKernel.setArg<cl::Buffer>(0, cl.getForceBuffers().getDeviceBuffer());
            ewaldForcesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2081
            ewaldForcesKernel.setArg<cl::Buffer>(2, cosSinSums.getDeviceBuffer());
2082
        }
Peter Eastman's avatar
Peter Eastman committed
2083
        if (pmeGrid1.isInitialized()) {
2084
2085
            // Create kernels for Coulomb PME.
            
2086
2087
2088
            map<string, string> replacements;
            replacements["CHARGE"] = (usePosqCharges ? "pos.w" : "charges[atom]");
            cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::pme, replacements), pmeDefines);
2089
            pmeUpdateBsplinesKernel = cl::Kernel(program, "updateBsplines");
2090
            pmeAtomRangeKernel = cl::Kernel(program, "findAtomRangeForGrid");
2091
            pmeZIndexKernel = cl::Kernel(program, "recordZIndex");
2092
2093
            pmeSpreadChargeKernel = cl::Kernel(program, "gridSpreadCharge");
            pmeConvolutionKernel = cl::Kernel(program, "reciprocalConvolution");
2094
            pmeEvalEnergyKernel = cl::Kernel(program, "gridEvaluateEnergy");
2095
            pmeInterpolateForceKernel = cl::Kernel(program, "gridInterpolateForce");
2096
            int elementSize = (cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
2097
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2098
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(1, pmeBsplineTheta.getDeviceBuffer());
2099
            pmeUpdateBsplinesKernel.setArg(2, OpenCLContext::ThreadBlockSize*PmeOrder*elementSize, NULL);
peastman's avatar
peastman committed
2100
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(3, pmeAtomGridIndex.getDeviceBuffer());
2101
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(12, charges.getDeviceBuffer());
peastman's avatar
peastman committed
2102
2103
            pmeAtomRangeKernel.setArg<cl::Buffer>(0, pmeAtomGridIndex.getDeviceBuffer());
            pmeAtomRangeKernel.setArg<cl::Buffer>(1, pmeAtomRange.getDeviceBuffer());
2104
            pmeAtomRangeKernel.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2105
            pmeZIndexKernel.setArg<cl::Buffer>(0, pmeAtomGridIndex.getDeviceBuffer());
2106
            pmeZIndexKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
2107
            pmeSpreadChargeKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2108
2109
            pmeSpreadChargeKernel.setArg<cl::Buffer>(1, pmeAtomGridIndex.getDeviceBuffer());
            pmeSpreadChargeKernel.setArg<cl::Buffer>(2, pmeAtomRange.getDeviceBuffer());
peastman's avatar
peastman committed
2110
            if (cl.getSupports64BitGlobalAtomics())
peastman's avatar
peastman committed
2111
                pmeSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid2.getDeviceBuffer());
peastman's avatar
peastman committed
2112
            else
Peter Eastman's avatar
Peter Eastman committed
2113
                pmeSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid1.getDeviceBuffer());
peastman's avatar
peastman committed
2114
            pmeSpreadChargeKernel.setArg<cl::Buffer>(4, pmeBsplineTheta.getDeviceBuffer());
2115
2116
2117
2118
            if (deviceIsCpu || cl.getSupports64BitGlobalAtomics())
                pmeSpreadChargeKernel.setArg<cl::Buffer>(13, charges.getDeviceBuffer());
            else
                pmeSpreadChargeKernel.setArg<cl::Buffer>(5, charges.getDeviceBuffer());
peastman's avatar
peastman committed
2119
2120
2121
2122
2123
2124
2125
2126
2127
            pmeConvolutionKernel.setArg<cl::Buffer>(0, pmeGrid2.getDeviceBuffer());
            pmeConvolutionKernel.setArg<cl::Buffer>(1, pmeBsplineModuliX.getDeviceBuffer());
            pmeConvolutionKernel.setArg<cl::Buffer>(2, pmeBsplineModuliY.getDeviceBuffer());
            pmeConvolutionKernel.setArg<cl::Buffer>(3, pmeBsplineModuliZ.getDeviceBuffer());
            pmeEvalEnergyKernel.setArg<cl::Buffer>(0, pmeGrid2.getDeviceBuffer());
            pmeEvalEnergyKernel.setArg<cl::Buffer>(1, usePmeQueue ? pmeEnergyBuffer.getDeviceBuffer() : cl.getEnergyBuffer().getDeviceBuffer());
            pmeEvalEnergyKernel.setArg<cl::Buffer>(2, pmeBsplineModuliX.getDeviceBuffer());
            pmeEvalEnergyKernel.setArg<cl::Buffer>(3, pmeBsplineModuliY.getDeviceBuffer());
            pmeEvalEnergyKernel.setArg<cl::Buffer>(4, pmeBsplineModuliZ.getDeviceBuffer());
2128
2129
            pmeInterpolateForceKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
            pmeInterpolateForceKernel.setArg<cl::Buffer>(1, cl.getForceBuffers().getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
2130
            pmeInterpolateForceKernel.setArg<cl::Buffer>(2, pmeGrid1.getDeviceBuffer());
peastman's avatar
peastman committed
2131
            pmeInterpolateForceKernel.setArg<cl::Buffer>(11, pmeAtomGridIndex.getDeviceBuffer());
2132
            pmeInterpolateForceKernel.setArg<cl::Buffer>(12, charges.getDeviceBuffer());
2133
2134
            if (cl.getSupports64BitGlobalAtomics()) {
                pmeFinishSpreadChargeKernel = cl::Kernel(program, "finishSpreadCharge");
peastman's avatar
peastman committed
2135
                pmeFinishSpreadChargeKernel.setArg<cl::Buffer>(0, pmeGrid2.getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
2136
                pmeFinishSpreadChargeKernel.setArg<cl::Buffer>(1, pmeGrid1.getDeviceBuffer());
2137
            }
2138
2139
            if (usePmeQueue)
                syncQueue->setKernel(cl::Kernel(program, "addEnergy"));
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160

            if (doLJPME) {
                // Create kernels for LJ PME.

                pmeDefines["EWALD_ALPHA"] = cl.doubleToString(dispersionAlpha);
                pmeDefines["GRID_SIZE_X"] = cl.intToString(dispersionGridSizeX);
                pmeDefines["GRID_SIZE_Y"] = cl.intToString(dispersionGridSizeY);
                pmeDefines["GRID_SIZE_Z"] = cl.intToString(dispersionGridSizeZ);
                pmeDefines["EPSILON_FACTOR"] = "1";
                pmeDefines["RECIP_EXP_FACTOR"] = cl.doubleToString(M_PI*M_PI/(dispersionAlpha*dispersionAlpha));
                pmeDefines["USE_LJPME"] = "1";
                program = cl.createProgram(OpenCLKernelSources::pme, pmeDefines);
                pmeDispersionUpdateBsplinesKernel = cl::Kernel(program, "updateBsplines");
                pmeDispersionAtomRangeKernel = cl::Kernel(program, "findAtomRangeForGrid");
                pmeDispersionZIndexKernel = cl::Kernel(program, "recordZIndex");
                pmeDispersionSpreadChargeKernel = cl::Kernel(program, "gridSpreadCharge");
                pmeDispersionConvolutionKernel = cl::Kernel(program, "reciprocalConvolution");
                pmeDispersionEvalEnergyKernel = cl::Kernel(program, "gridEvaluateEnergy");
                pmeDispersionInterpolateForceKernel = cl::Kernel(program, "gridInterpolateForce");
                int elementSize = (cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
                pmeDispersionUpdateBsplinesKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2161
                pmeDispersionUpdateBsplinesKernel.setArg<cl::Buffer>(1, pmeBsplineTheta.getDeviceBuffer());
2162
                pmeDispersionUpdateBsplinesKernel.setArg(2, OpenCLContext::ThreadBlockSize*PmeOrder*elementSize, NULL);
peastman's avatar
peastman committed
2163
2164
2165
2166
                pmeDispersionUpdateBsplinesKernel.setArg<cl::Buffer>(3, pmeAtomGridIndex.getDeviceBuffer());
                pmeDispersionUpdateBsplinesKernel.setArg<cl::Buffer>(12, sigmaEpsilon.getDeviceBuffer());
                pmeDispersionAtomRangeKernel.setArg<cl::Buffer>(0, pmeAtomGridIndex.getDeviceBuffer());
                pmeDispersionAtomRangeKernel.setArg<cl::Buffer>(1, pmeAtomRange.getDeviceBuffer());
2167
                pmeDispersionAtomRangeKernel.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2168
                pmeDispersionZIndexKernel.setArg<cl::Buffer>(0, pmeAtomGridIndex.getDeviceBuffer());
2169
2170
                pmeDispersionZIndexKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
                pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
2171
2172
                pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(1, pmeAtomGridIndex.getDeviceBuffer());
                pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(2, pmeAtomRange.getDeviceBuffer());
2173
                if (cl.getSupports64BitGlobalAtomics())
peastman's avatar
peastman committed
2174
                    pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid2.getDeviceBuffer());
2175
                else
Peter Eastman's avatar
Peter Eastman committed
2176
                    pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid1.getDeviceBuffer());
peastman's avatar
peastman committed
2177
                pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(4, pmeBsplineTheta.getDeviceBuffer());
peastman's avatar
peastman committed
2178
                if (deviceIsCpu || cl.getSupports64BitGlobalAtomics())
peastman's avatar
peastman committed
2179
                    pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(13, sigmaEpsilon.getDeviceBuffer());
peastman's avatar
peastman committed
2180
                else
peastman's avatar
peastman committed
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
                    pmeDispersionSpreadChargeKernel.setArg<cl::Buffer>(5, sigmaEpsilon.getDeviceBuffer());
                pmeDispersionConvolutionKernel.setArg<cl::Buffer>(0, pmeGrid2.getDeviceBuffer());
                pmeDispersionConvolutionKernel.setArg<cl::Buffer>(1, pmeDispersionBsplineModuliX.getDeviceBuffer());
                pmeDispersionConvolutionKernel.setArg<cl::Buffer>(2, pmeDispersionBsplineModuliY.getDeviceBuffer());
                pmeDispersionConvolutionKernel.setArg<cl::Buffer>(3, pmeDispersionBsplineModuliZ.getDeviceBuffer());
                pmeDispersionEvalEnergyKernel.setArg<cl::Buffer>(0, pmeGrid2.getDeviceBuffer());
                pmeDispersionEvalEnergyKernel.setArg<cl::Buffer>(1, usePmeQueue ? pmeEnergyBuffer.getDeviceBuffer() : cl.getEnergyBuffer().getDeviceBuffer());
                pmeDispersionEvalEnergyKernel.setArg<cl::Buffer>(2, pmeDispersionBsplineModuliX.getDeviceBuffer());
                pmeDispersionEvalEnergyKernel.setArg<cl::Buffer>(3, pmeDispersionBsplineModuliY.getDeviceBuffer());
                pmeDispersionEvalEnergyKernel.setArg<cl::Buffer>(4, pmeDispersionBsplineModuliZ.getDeviceBuffer());
2191
2192
                pmeDispersionInterpolateForceKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
                pmeDispersionInterpolateForceKernel.setArg<cl::Buffer>(1, cl.getForceBuffers().getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
2193
                pmeDispersionInterpolateForceKernel.setArg<cl::Buffer>(2, pmeGrid1.getDeviceBuffer());
peastman's avatar
peastman committed
2194
2195
                pmeDispersionInterpolateForceKernel.setArg<cl::Buffer>(11, pmeAtomGridIndex.getDeviceBuffer());
                pmeDispersionInterpolateForceKernel.setArg<cl::Buffer>(12, sigmaEpsilon.getDeviceBuffer());
2196
2197
                if (cl.getSupports64BitGlobalAtomics()) {
                    pmeDispersionFinishSpreadChargeKernel = cl::Kernel(program, "finishSpreadCharge");
peastman's avatar
peastman committed
2198
                    pmeDispersionFinishSpreadChargeKernel.setArg<cl::Buffer>(0, pmeGrid2.getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
2199
                    pmeDispersionFinishSpreadChargeKernel.setArg<cl::Buffer>(1, pmeGrid1.getDeviceBuffer());
2200
2201
                }
            }
2202
       }
2203
    }
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
    
    // Update particle and exception parameters.

    bool paramChanged = false;
    for (int i = 0; i < paramNames.size(); i++) {
        double value = context.getParameter(paramNames[i]);
        if (value != paramValues[i]) {
            paramValues[i] = value;;
            paramChanged = true;
        }
    }
    if (paramChanged) {
Peter Eastman's avatar
Peter Eastman committed
2216
        recomputeParams = true;
2217
        globalParams.upload(paramValues, true, true);
2218
    }
Peter Eastman's avatar
Peter Eastman committed
2219
2220
2221
2222
    double energy = (includeReciprocal ? ewaldSelfEnergy : 0.0);
    if (recomputeParams || hasOffsets) {
        computeParamsKernel.setArg<cl_int>(1, includeEnergy && includeReciprocal);
        cl.executeKernel(computeParamsKernel, cl.getPaddedNumAtoms());
Peter Eastman's avatar
Peter Eastman committed
2223
2224
        if (exclusionParams.isInitialized())
            cl.executeKernel(computeExclusionParamsKernel, exclusionParams.getSize());
2225
2226
2227
2228
2229
        if (usePmeQueue) {
            vector<cl::Event> events(1);
            cl.getQueue().enqueueMarker(&events[0]);
            pmeQueue.enqueueWaitForEvents(events);
        }
2230
2231
        if (hasOffsets)
            energy = 0.0; // The Ewald self energy was computed in the kernel.
2232
        recomputeParams = false;
Peter Eastman's avatar
Peter Eastman committed
2233
    }
2234
2235
2236
    
    // Do reciprocal space calculations.
    
peastman's avatar
peastman committed
2237
    if (cosSinSums.isInitialized() && includeReciprocal) {
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
        mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
        mm_double4 recipBoxSize = mm_double4(2*M_PI/boxSize.x, 2*M_PI/boxSize.y, 2*M_PI/boxSize.z, 0.0);
        double recipCoefficient = ONE_4PI_EPS0*4*M_PI/(boxSize.x*boxSize.y*boxSize.z);
        if (cl.getUseDoublePrecision()) {
            ewaldSumsKernel.setArg<mm_double4>(3, recipBoxSize);
            ewaldSumsKernel.setArg<cl_double>(4, recipCoefficient);
            ewaldForcesKernel.setArg<mm_double4>(3, recipBoxSize);
            ewaldForcesKernel.setArg<cl_double>(4, recipCoefficient);
        }
        else {
            ewaldSumsKernel.setArg<mm_float4>(3, mm_float4((float) recipBoxSize.x, (float) recipBoxSize.y, (float) recipBoxSize.z, 0));
            ewaldSumsKernel.setArg<cl_float>(4, (cl_float) recipCoefficient);
            ewaldForcesKernel.setArg<mm_float4>(3, mm_float4((float) recipBoxSize.x, (float) recipBoxSize.y, (float) recipBoxSize.z, 0));
            ewaldForcesKernel.setArg<cl_float>(4, (cl_float) recipCoefficient);
        }
peastman's avatar
peastman committed
2253
        cl.executeKernel(ewaldSumsKernel, cosSinSums.getSize());
2254
2255
        cl.executeKernel(ewaldForcesKernel, cl.getNumAtoms());
    }
Peter Eastman's avatar
Peter Eastman committed
2256
    if (pmeGrid1.isInitialized() && includeReciprocal) {
2257
        if (usePmeQueue && !includeEnergy)
2258
            cl.setQueue(pmeQueue);
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
        
        // Invert the periodic box vectors.
        
        Vec3 boxVectors[3];
        cl.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
        double determinant = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2];
        double scale = 1.0/determinant;
        mm_double4 recipBoxVectors[3];
        recipBoxVectors[0] = mm_double4(boxVectors[1][1]*boxVectors[2][2]*scale, 0, 0, 0);
        recipBoxVectors[1] = mm_double4(-boxVectors[1][0]*boxVectors[2][2]*scale, boxVectors[0][0]*boxVectors[2][2]*scale, 0, 0);
        recipBoxVectors[2] = mm_double4((boxVectors[1][0]*boxVectors[2][1]-boxVectors[1][1]*boxVectors[2][0])*scale, -boxVectors[0][0]*boxVectors[2][1]*scale, boxVectors[0][0]*boxVectors[1][1]*scale, 0);
        mm_float4 recipBoxVectorsFloat[3];
        for (int i = 0; i < 3; i++)
            recipBoxVectorsFloat[i] = mm_float4((float) recipBoxVectors[i].x, (float) recipBoxVectors[i].y, (float) recipBoxVectors[i].z, 0);
        
        // Execute the reciprocal space kernels.

peastman's avatar
peastman committed
2276
2277
        if (hasCoulomb) {
            setPeriodicBoxArgs(cl, pmeUpdateBsplinesKernel, 4);
2278
            if (cl.getUseDoublePrecision()) {
peastman's avatar
peastman committed
2279
2280
2281
                pmeUpdateBsplinesKernel.setArg<mm_double4>(9, recipBoxVectors[0]);
                pmeUpdateBsplinesKernel.setArg<mm_double4>(10, recipBoxVectors[1]);
                pmeUpdateBsplinesKernel.setArg<mm_double4>(11, recipBoxVectors[2]);
2282
2283
            }
            else {
peastman's avatar
peastman committed
2284
2285
2286
                pmeUpdateBsplinesKernel.setArg<mm_float4>(9, recipBoxVectorsFloat[0]);
                pmeUpdateBsplinesKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[1]);
                pmeUpdateBsplinesKernel.setArg<mm_float4>(11, recipBoxVectorsFloat[2]);
2287
            }
peastman's avatar
peastman committed
2288
2289
            cl.executeKernel(pmeUpdateBsplinesKernel, cl.getNumAtoms());
            if (deviceIsCpu && !cl.getSupports64BitGlobalAtomics()) {
2290
                setPeriodicBoxArgs(cl, pmeSpreadChargeKernel, 5);
2291
                if (cl.getUseDoublePrecision()) {
2292
2293
2294
                    pmeSpreadChargeKernel.setArg<mm_double4>(10, recipBoxVectors[0]);
                    pmeSpreadChargeKernel.setArg<mm_double4>(11, recipBoxVectors[1]);
                    pmeSpreadChargeKernel.setArg<mm_double4>(12, recipBoxVectors[2]);
2295
2296
                }
                else {
2297
2298
2299
                    pmeSpreadChargeKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[0]);
                    pmeSpreadChargeKernel.setArg<mm_float4>(11, recipBoxVectorsFloat[1]);
                    pmeSpreadChargeKernel.setArg<mm_float4>(12, recipBoxVectorsFloat[2]);
2300
                }
peastman's avatar
peastman committed
2301
                cl.executeKernel(pmeSpreadChargeKernel, 2*cl.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(), 1);
2302
            }
2303
            else {
peastman's avatar
peastman committed
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
                sort->sort(pmeAtomGridIndex);
                if (cl.getSupports64BitGlobalAtomics()) {
                    setPeriodicBoxArgs(cl, pmeSpreadChargeKernel, 5);
                    if (cl.getUseDoublePrecision()) {
                        pmeSpreadChargeKernel.setArg<mm_double4>(10, recipBoxVectors[0]);
                        pmeSpreadChargeKernel.setArg<mm_double4>(11, recipBoxVectors[1]);
                        pmeSpreadChargeKernel.setArg<mm_double4>(12, recipBoxVectors[2]);
                    }
                    else {
                        pmeSpreadChargeKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[0]);
                        pmeSpreadChargeKernel.setArg<mm_float4>(11, recipBoxVectorsFloat[1]);
                        pmeSpreadChargeKernel.setArg<mm_float4>(12, recipBoxVectorsFloat[2]);
                    }
                    cl.executeKernel(pmeSpreadChargeKernel, cl.getNumAtoms());
                    cl.executeKernel(pmeFinishSpreadChargeKernel, gridSizeX*gridSizeY*gridSizeZ);
                }
                else {
                    cl.executeKernel(pmeAtomRangeKernel, cl.getNumAtoms());
                    setPeriodicBoxSizeArg(cl, pmeZIndexKernel, 2);
                    if (cl.getUseDoublePrecision())
                        pmeZIndexKernel.setArg<mm_double4>(3, recipBoxVectors[2]);
                    else
                        pmeZIndexKernel.setArg<mm_float4>(3, recipBoxVectorsFloat[2]);
                    cl.executeKernel(pmeZIndexKernel, cl.getNumAtoms());
                    cl.executeKernel(pmeSpreadChargeKernel, cl.getNumAtoms());
                }
2330
            }
Peter Eastman's avatar
Peter Eastman committed
2331
            fft->execFFT(pmeGrid1, pmeGrid2, true);
peastman's avatar
peastman committed
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
2350
2351
            mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
            if (cl.getUseDoublePrecision()) {
                pmeConvolutionKernel.setArg<mm_double4>(4, recipBoxVectors[0]);
                pmeConvolutionKernel.setArg<mm_double4>(5, recipBoxVectors[1]);
                pmeConvolutionKernel.setArg<mm_double4>(6, recipBoxVectors[2]);
                pmeEvalEnergyKernel.setArg<mm_double4>(5, recipBoxVectors[0]);
                pmeEvalEnergyKernel.setArg<mm_double4>(6, recipBoxVectors[1]);
                pmeEvalEnergyKernel.setArg<mm_double4>(7, recipBoxVectors[2]);
            }
            else {
                pmeConvolutionKernel.setArg<mm_float4>(4, recipBoxVectorsFloat[0]);
                pmeConvolutionKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[1]);
                pmeConvolutionKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[2]);
                pmeEvalEnergyKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[0]);
                pmeEvalEnergyKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[1]);
                pmeEvalEnergyKernel.setArg<mm_float4>(7, recipBoxVectorsFloat[2]);
            }
            if (includeEnergy)
                cl.executeKernel(pmeEvalEnergyKernel, gridSizeX*gridSizeY*gridSizeZ);
            cl.executeKernel(pmeConvolutionKernel, gridSizeX*gridSizeY*gridSizeZ);
Peter Eastman's avatar
Peter Eastman committed
2352
            fft->execFFT(pmeGrid2, pmeGrid1, false);
peastman's avatar
peastman committed
2353
2354
2355
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
            setPeriodicBoxArgs(cl, pmeInterpolateForceKernel, 3);
            if (cl.getUseDoublePrecision()) {
                pmeInterpolateForceKernel.setArg<mm_double4>(8, recipBoxVectors[0]);
                pmeInterpolateForceKernel.setArg<mm_double4>(9, recipBoxVectors[1]);
                pmeInterpolateForceKernel.setArg<mm_double4>(10, recipBoxVectors[2]);
            }
            else {
                pmeInterpolateForceKernel.setArg<mm_float4>(8, recipBoxVectorsFloat[0]);
                pmeInterpolateForceKernel.setArg<mm_float4>(9, recipBoxVectorsFloat[1]);
                pmeInterpolateForceKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[2]);
            }
            if (deviceIsCpu)
                cl.executeKernel(pmeInterpolateForceKernel, 2*cl.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(), 1);
            else
                cl.executeKernel(pmeInterpolateForceKernel, cl.getNumAtoms());
2368
        }
2369
        
peastman's avatar
peastman committed
2370
        if (doLJPME && hasLJ) {
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
            setPeriodicBoxArgs(cl, pmeDispersionUpdateBsplinesKernel, 4);
            if (cl.getUseDoublePrecision()) {
                pmeDispersionUpdateBsplinesKernel.setArg<mm_double4>(9, recipBoxVectors[0]);
                pmeDispersionUpdateBsplinesKernel.setArg<mm_double4>(10, recipBoxVectors[1]);
                pmeDispersionUpdateBsplinesKernel.setArg<mm_double4>(11, recipBoxVectors[2]);
            }
            else {
                pmeDispersionUpdateBsplinesKernel.setArg<mm_float4>(9, recipBoxVectorsFloat[0]);
                pmeDispersionUpdateBsplinesKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[1]);
                pmeDispersionUpdateBsplinesKernel.setArg<mm_float4>(11, recipBoxVectorsFloat[2]);
            }
            cl.executeKernel(pmeDispersionUpdateBsplinesKernel, cl.getNumAtoms());
            if (deviceIsCpu && !cl.getSupports64BitGlobalAtomics()) {
Peter Eastman's avatar
Peter Eastman committed
2384
                cl.clearBuffer(pmeGrid1);
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
                setPeriodicBoxArgs(cl, pmeDispersionSpreadChargeKernel, 5);
                if (cl.getUseDoublePrecision()) {
                    pmeDispersionSpreadChargeKernel.setArg<mm_double4>(10, recipBoxVectors[0]);
                    pmeDispersionSpreadChargeKernel.setArg<mm_double4>(11, recipBoxVectors[1]);
                    pmeDispersionSpreadChargeKernel.setArg<mm_double4>(12, recipBoxVectors[2]);
                }
                else {
                    pmeDispersionSpreadChargeKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[0]);
                    pmeDispersionSpreadChargeKernel.setArg<mm_float4>(11, recipBoxVectorsFloat[1]);
                    pmeDispersionSpreadChargeKernel.setArg<mm_float4>(12, recipBoxVectorsFloat[2]);
                }
                cl.executeKernel(pmeDispersionSpreadChargeKernel, 2*cl.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(), 1);
            }
            else {
                if (cl.getSupports64BitGlobalAtomics()) {
2400
2401
                    if (!hasCoulomb)
                        sort->sort(pmeAtomGridIndex);
peastman's avatar
peastman committed
2402
                    cl.clearBuffer(pmeGrid2);
2403
2404
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
                    setPeriodicBoxArgs(cl, pmeDispersionSpreadChargeKernel, 5);
                    if (cl.getUseDoublePrecision()) {
                        pmeDispersionSpreadChargeKernel.setArg<mm_double4>(10, recipBoxVectors[0]);
                        pmeDispersionSpreadChargeKernel.setArg<mm_double4>(11, recipBoxVectors[1]);
                        pmeDispersionSpreadChargeKernel.setArg<mm_double4>(12, recipBoxVectors[2]);
                    }
                    else {
                        pmeDispersionSpreadChargeKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[0]);
                        pmeDispersionSpreadChargeKernel.setArg<mm_float4>(11, recipBoxVectorsFloat[1]);
                        pmeDispersionSpreadChargeKernel.setArg<mm_float4>(12, recipBoxVectorsFloat[2]);
                    }
                    cl.executeKernel(pmeDispersionSpreadChargeKernel, cl.getNumAtoms());
                    cl.executeKernel(pmeDispersionFinishSpreadChargeKernel, gridSizeX*gridSizeY*gridSizeZ);
                }
                else {
2418
                    sort->sort(pmeAtomGridIndex);
Peter Eastman's avatar
Peter Eastman committed
2419
                    cl.clearBuffer(pmeGrid1);
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
                    cl.executeKernel(pmeDispersionAtomRangeKernel, cl.getNumAtoms());
                    setPeriodicBoxSizeArg(cl, pmeDispersionZIndexKernel, 2);
                    if (cl.getUseDoublePrecision())
                        pmeDispersionZIndexKernel.setArg<mm_double4>(3, recipBoxVectors[2]);
                    else
                        pmeDispersionZIndexKernel.setArg<mm_float4>(3, recipBoxVectorsFloat[2]);
                    cl.executeKernel(pmeDispersionZIndexKernel, cl.getNumAtoms());
                    cl.executeKernel(pmeDispersionSpreadChargeKernel, cl.getNumAtoms());
                }
            }
Peter Eastman's avatar
Peter Eastman committed
2430
            dispersionFft->execFFT(pmeGrid1, pmeGrid2, true);
2431
2432
2433
2434
2435
2436
2437
2438
2439
2440
2441
2442
2443
2444
2445
2446
2447
            mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
            if (cl.getUseDoublePrecision()) {
                pmeDispersionConvolutionKernel.setArg<mm_double4>(4, recipBoxVectors[0]);
                pmeDispersionConvolutionKernel.setArg<mm_double4>(5, recipBoxVectors[1]);
                pmeDispersionConvolutionKernel.setArg<mm_double4>(6, recipBoxVectors[2]);
                pmeDispersionEvalEnergyKernel.setArg<mm_double4>(5, recipBoxVectors[0]);
                pmeDispersionEvalEnergyKernel.setArg<mm_double4>(6, recipBoxVectors[1]);
                pmeDispersionEvalEnergyKernel.setArg<mm_double4>(7, recipBoxVectors[2]);
            }
            else {
                pmeDispersionConvolutionKernel.setArg<mm_float4>(4, recipBoxVectorsFloat[0]);
                pmeDispersionConvolutionKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[1]);
                pmeDispersionConvolutionKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[2]);
                pmeDispersionEvalEnergyKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[0]);
                pmeDispersionEvalEnergyKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[1]);
                pmeDispersionEvalEnergyKernel.setArg<mm_float4>(7, recipBoxVectorsFloat[2]);
            }
Andy Simmonett's avatar
Andy Simmonett committed
2448
            if (!hasCoulomb) cl.clearBuffer(pmeEnergyBuffer);
2449
2450
2451
            if (includeEnergy)
                cl.executeKernel(pmeDispersionEvalEnergyKernel, gridSizeX*gridSizeY*gridSizeZ);
            cl.executeKernel(pmeDispersionConvolutionKernel, gridSizeX*gridSizeY*gridSizeZ);
Peter Eastman's avatar
Peter Eastman committed
2452
            dispersionFft->execFFT(pmeGrid2, pmeGrid1, false);
2453
2454
2455
2456
2457
2458
2459
2460
2461
2462
2463
2464
2465
2466
2467
2468
            setPeriodicBoxArgs(cl, pmeDispersionInterpolateForceKernel, 3);
            if (cl.getUseDoublePrecision()) {
                pmeDispersionInterpolateForceKernel.setArg<mm_double4>(8, recipBoxVectors[0]);
                pmeDispersionInterpolateForceKernel.setArg<mm_double4>(9, recipBoxVectors[1]);
                pmeDispersionInterpolateForceKernel.setArg<mm_double4>(10, recipBoxVectors[2]);
            }
            else {
                pmeDispersionInterpolateForceKernel.setArg<mm_float4>(8, recipBoxVectorsFloat[0]);
                pmeDispersionInterpolateForceKernel.setArg<mm_float4>(9, recipBoxVectorsFloat[1]);
                pmeDispersionInterpolateForceKernel.setArg<mm_float4>(10, recipBoxVectorsFloat[2]);
            }
            if (deviceIsCpu)
                cl.executeKernel(pmeDispersionInterpolateForceKernel, 2*cl.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(), 1);
            else
                cl.executeKernel(pmeDispersionInterpolateForceKernel, cl.getNumAtoms());
        }
2469
2470
2471
2472
        if (usePmeQueue) {
            pmeQueue.enqueueMarker(&pmeSyncEvent);
            cl.restoreDefaultQueue();
        }
2473
    }
2474
    if (dispersionCoefficient != 0.0 && includeDirect) {
2475
        mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
2476
2477
2478
        energy += dispersionCoefficient/(boxSize.x*boxSize.y*boxSize.z);
    }
    return energy;
2479
2480
}

2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
2497
2498
2499
2500
/**
 * Re-read the particle and exception parameters from a NonbondedForce and upload
 * them to the device.
 *
 * @param context  the context whose System is used when recomputing the dispersion correction
 * @param force    the force to copy parameters from
 * @throws OpenMMException if the number of particles differs from the number of atoms
 *         in the OpenCL context, if a nonzero charge or epsilon appears while this kernel
 *         was built without Coulomb or Lennard-Jones interactions, or if the set of
 *         non-excluded exceptions no longer matches the recorded one
 */
void OpenCLCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
    // Reject any change this kernel cannot represent.

    const int particleCount = force.getNumParticles();
    if (particleCount != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");
    if (!(hasCoulomb && hasLJ)) {
        for (int atom = 0; atom < particleCount; atom++) {
            double charge, sigma, epsilon;
            force.getParticleParameters(atom, charge, sigma, epsilon);
            if (charge != 0.0 && !hasCoulomb)
                throw OpenMMException("updateParametersInContext: The nonbonded force kernel does not include Coulomb interactions, because all charges were originally 0");
            if (epsilon != 0.0 && !hasLJ)
                throw OpenMMException("updateParametersInContext: The nonbonded force kernel does not include Lennard-Jones interactions, because all epsilons were originally 0");
        }
    }

    // Walk the force's exceptions in order.  Each one must either be the next entry
    // of exceptionAtoms (and is then retained) or have zero chargeProd and epsilon.

    vector<int> keptExceptions;
    const int exceptionCount = force.getNumExceptions();
    for (int index = 0; index < exceptionCount; index++) {
        int atomA, atomB;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(index, atomA, atomB, chargeProd, sigma, epsilon);
        const bool isNextRecorded = (exceptionAtoms.size() > keptExceptions.size() &&
                                     make_pair(atomA, atomB) == exceptionAtoms[keptExceptions.size()]);
        if (isNextRecorded) {
            keptExceptions.push_back(index);
            continue;
        }
        if (chargeProd != 0.0 || epsilon != 0.0)
            throw OpenMMException("updateParametersInContext: The set of non-excluded exceptions has changed");
    }

    // This context handles the slice [sliceBegin, sliceEnd) of the retained exceptions.

    const int contextCount = cl.getPlatformData().contexts.size();
    const int sliceBegin = cl.getContextIndex()*keptExceptions.size()/contextCount;
    const int sliceEnd = (cl.getContextIndex()+1)*keptExceptions.size()/contextCount;
    const int sliceLength = sliceEnd-sliceBegin;

    // Upload the per-particle parameters; padding atoms keep all-zero entries.

    vector<mm_float4> particleValues(cl.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
    for (int atom = 0; atom < particleCount; atom++) {
        double charge, sigma, epsilon;
        force.getParticleParameters(atom, charge, sigma, epsilon);
        particleValues[atom] = mm_float4(charge, sigma, epsilon, 0);
    }
    baseParticleParams.upload(particleValues);

    // Upload the parameters of this context's exceptions.

    if (sliceLength > 0) {
        vector<mm_float4> exceptionValues(sliceLength);
        for (int slot = 0; slot < sliceLength; slot++) {
            int atomA, atomB;
            double chargeProd, sigma, epsilon;
            force.getExceptionParameters(keptExceptions[sliceBegin+slot], atomA, atomB, chargeProd, sigma, epsilon);
            exceptionValues[slot] = mm_float4(chargeProd, sigma, epsilon, 0);
        }
        baseExceptionParams.upload(exceptionValues);
    }

    // Recompute the Ewald self energy; only context 0 accumulates it.

    ewaldSelfEnergy = 0.0;
    const bool usesEwaldSum = (nonbondedMethod == Ewald || nonbondedMethod == PME || nonbondedMethod == LJPME);
    if (usesEwaldSum && cl.getContextIndex() == 0) {
        for (int atom = 0; atom < particleCount; atom++) {
            const mm_float4& p = particleValues[atom];
            ewaldSelfEnergy -= p.x*p.x*ONE_4PI_EPS0*alpha/sqrt(M_PI);
            if (doLJPME)
                ewaldSelfEnergy += p.z*pow(p.y*dispersionAlpha, 6)/3.0;
        }
    }

    // Recompute the dispersion correction coefficient where it applies.

    const bool methodHasCorrection = (nonbondedMethod == CutoffPeriodic || nonbondedMethod == Ewald || nonbondedMethod == PME);
    if (force.getUseDispersionCorrection() && cl.getContextIndex() == 0 && methodHasCorrection)
        dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force);
    cl.invalidateMolecules(info);

    // Flag the parameters so the next force evaluation recomputes them.

    recomputeParams = true;
}

2552
2553
2554
2555
2556
2557
2558
2559
2560
2561
2562
2563
2564
/**
 * Report the PME parameters in use.
 *
 * @param alpha  on exit, the Ewald separation parameter
 * @param nx     on exit, the grid size along x
 * @param ny     on exit, the grid size along y
 * @param nz     on exit, the grid size along z
 * @throws OpenMMException if the nonbonded method is not PME
 */
void OpenCLCalcNonbondedForceKernel::getPMEParameters(double& alpha, int& nx, int& ny, int& nz) const {
    if (nonbondedMethod != PME)
        throw OpenMMException("getPMEParametersInContext: This Context is not using PME");

    // When the reciprocal space part runs through the CPU PME kernel, let it answer.

    if (cl.getPlatformData().useCpuPme) {
        cpuPme.getAs<CalcPmeReciprocalForceKernel>().getPMEParameters(alpha, nx, ny, nz);
        return;
    }

    // Otherwise report the values stored by this kernel.

    nx = gridSizeX;
    ny = gridSizeY;
    nz = gridSizeZ;
    alpha = this->alpha;
}

2565
void OpenCLCalcNonbondedForceKernel::getLJPMEParameters(double& alpha, int& nx, int& ny, int& nz) const {
2566
2567
2568
    if (nonbondedMethod != LJPME)
        throw OpenMMException("getPMEParametersInContext: This Context is not using PME");
    if (cl.getPlatformData().useCpuPme)
2569
2570
        //cpuPme.getAs<CalcPmeReciprocalForceKernel>().getLJPMEParameters(alpha, nx, ny, nz);
        throw OpenMMException("getPMEParametersInContext: CPUPME has not been implemented for LJPME yet.");
2571
    else {
2572
2573
2574
2575
        alpha = this->dispersionAlpha;
        nx = dispersionGridSizeX;
        ny = dispersionGridSizeY;
        nz = dispersionGridSizeZ;
2576
2577
2578
    }
}

2579
/**
 * Describes a CustomNonbondedForce to the OpenCL context: which particles are
 * interchangeable and which particles form groups (one group per exclusion).
 */
class OpenCLCalcCustomNonbondedForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(int requiredBuffers, const CustomNonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
        const int numGroups = force.getNumInteractionGroups();
        if (numGroups <= 0)
            return;

        // For every particle, record which side of which interaction group it is on:
        // 2*group for the first set, 2*group+1 for the second.

        groupsForParticle.resize(force.getNumParticles());
        for (int group = 0; group < numGroups; group++) {
            set<int> firstSet, secondSet;
            force.getInteractionGroupParameters(group, firstSet, secondSet);
            for (int atom : firstSet)
                groupsForParticle[atom].insert(2*group);
            for (int atom : secondSet)
                groupsForParticle[atom].insert(2*group+1);
        }
    }
    /**
     * Two particles are identical when all their per-particle parameters match and,
     * if interaction groups are in use, they belong to the same group sides.
     */
    bool areParticlesIdentical(int particle1, int particle2) {
        vector<double> first;
        vector<double> second;
        force.getParticleParameters(particle1, first);
        force.getParticleParameters(particle2, second);
        const int count = (int) first.size();
        for (int k = 0; k < count; k++) {
            if (first[k] != second[k])
                return false;
        }
        return groupsForParticle.empty() || groupsForParticle[particle1] == groupsForParticle[particle2];
    }
    /**
     * Each exclusion is reported as one particle group.
     */
    int getNumParticleGroups() {
        return force.getNumExclusions();
    }
    /**
     * Return the two particles of the exclusion with the given index.
     */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int first, second;
        force.getExclusionParticles(index, first, second);
        particles.clear();
        particles.push_back(first);
        particles.push_back(second);
    }
    /**
     * Exclusions carry no parameters, so all groups are considered identical.
     */
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomNonbondedForce& force;
    // Indexed by particle; left empty when the force has no interaction groups.
    vector<set<int> > groupsForParticle;
};

/**
 * Release the objects this kernel allocated with new.
 */
OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
    // delete on a null pointer does nothing, so no explicit checks are needed.
    delete params;
    delete forceCopy;
}

void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) {
    int forceIndex;
    for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
        ;
2635
    string prefix = (force.getNumInteractionGroups() == 0 ? "custom"+cl.intToString(forceIndex)+"_" : "");
2636
2637
2638
2639

    // Record parameters and exclusions.

    int numParticles = force.getNumParticles();
2640
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customNonbondedParameters");
2641
    if (force.getNumGlobalParameters() > 0)
peastman's avatar
peastman committed
2642
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customNonbondedGlobals", CL_MEM_READ_ONLY);
2643
    vector<vector<cl_float> > paramVector(numParticles);
2644
2645
2646
2647
    vector<vector<int> > exclusionList(numParticles);
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
        force.getParticleParameters(i, parameters);
2648
        paramVector[i].resize(parameters.size());
2649
        for (int j = 0; j < (int) parameters.size(); j++)
2650
            paramVector[i][j] = (cl_float) parameters[j];
2651
2652
        exclusionList[i].push_back(i);
    }
2653
2654
2655
2656
2657
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int particle1, particle2;
        force.getExclusionParticles(i, particle1, particle2);
        exclusionList[particle1].push_back(particle2);
        exclusionList[particle2].push_back(particle1);
2658
    }
2659
    params->setParameterValues(paramVector);
2660
2661
2662

    // Record the tabulated functions.

2663
2664
    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
2665
    vector<const TabulatedFunction*> functionList;
2666
    vector<string> tableTypes;
peastman's avatar
peastman committed
2667
2668
    tabulatedFunctions.resize(force.getNumTabulatedFunctions());
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
2669
2670
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
2671
        string arrayName = prefix+"table"+cl.intToString(i);
2672
        functionDefinitions.push_back(make_pair(name, arrayName));
2673
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
2674
        int width;
2675
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
2676
2677
2678
        tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
        tabulatedFunctions[i].upload(f);
        cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[i].getDeviceBuffer()));
2679
2680
2681
2682
        if (width == 1)
            tableTypes.push_back("float");
        else
            tableTypes.push_back("float"+cl.intToString(width));
2683
2684
2685
2686
2687
2688
2689
2690
2691
2692
    }

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
peastman's avatar
peastman committed
2693
2694
    if (globals.isInitialized())
        globals.upload(globalParamValues);
2695
2696
    bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic);
2697
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
2698
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
2699
    map<string, Lepton::ParsedExpression> forceExpressions;
2700
    forceExpressions["real customEnergy = "] = energyExpression;
2701
    forceExpressions["tempForce -= "] = forceExpression;
2702
2703
2704

    // Create the kernels.

2705
2706
2707
2708
2709
    vector<pair<ExpressionTreeNode, string> > variables;
    ExpressionTreeNode rnode(new Operation::Variable("r"));
    variables.push_back(make_pair(rnode, "r"));
    variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
    variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
2710
2711
    for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
        const string& name = force.getPerParticleParameterName(i);
2712
2713
        variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1")));
        variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2")));
2714
2715
2716
    }
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        const string& name = force.getGlobalParameterName(i);
2717
        string value = "globals["+cl.intToString(i)+"]";
2718
        variables.push_back(makeVariable(name, prefix+value));
2719
    }
2720
2721
2722
2723
2724
2725
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string paramName = force.getEnergyParameterDerivativeName(i);
        string derivVariable = cl.getNonbondedUtilities().addEnergyParameterDerivative(paramName);
        Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
        forceExpressions[derivVariable+" += interactionScale*switchValue*"] = derivExpression;
    }
2726
    stringstream compute;
2727
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp");
2728
2729
    map<string, string> replacements;
    replacements["COMPUTE_FORCE"] = compute.str();
2730
2731
2732
2733
2734
2735
2736
2737
2738
    replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
    if (force.getUseSwitchingFunction()) {
        // Compute the switching coefficients.
        
        replacements["SWITCH_CUTOFF"] = cl.doubleToString(force.getSwitchingDistance());
        replacements["SWITCH_C3"] = cl.doubleToString(10/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 3.0));
        replacements["SWITCH_C4"] = cl.doubleToString(15/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 4.0));
        replacements["SWITCH_C5"] = cl.doubleToString(6/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 5.0));
    }
2739
    string source = cl.replaceStrings(OpenCLKernelSources::customNonbonded, replacements);
2740
    if (force.getNumInteractionGroups() > 0)
2741
        initInteractionGroups(force, source, tableTypes);
2742
2743
2744
2745
2746
2747
    else {
        cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
            cl.getNonbondedUtilities().addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"params"+cl.intToString(i+1), buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
        }
peastman's avatar
peastman committed
2748
2749
2750
        if (globals.isInitialized()) {
            globals.upload(globalParamValues);
            cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", 1, sizeof(cl_float), globals.getDeviceBuffer()));
2751
        }
2752
    }
2753
2754
    info = new ForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force);
    cl.addForce(info);
2755
2756
2757
2758
2759
2760
2761
2762
2763
2764
2765
    
    // Record information for the long range correction.
    
    if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection() && cl.getContextIndex() == 0) {
        forceCopy = new CustomNonbondedForce(force);
        hasInitializedLongRangeCorrection = false;
    }
    else {
        longRangeCoefficient = 0.0;
        hasInitializedLongRangeCorrection = true;
    }
2766
2767
}

2768
void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource, const vector<string>& tableTypes) {
2769
2770
2771
2772
    // Process groups to form tiles.
    
    vector<vector<int> > atomLists;
    vector<pair<int, int> > tiles;
2773
2774
    vector<int> tileGroup;
    vector<vector<int> > duplicateAtomsForGroup;
2775
2776
2777
2778
2779
2780
2781
2782
2783
2784
    for (int group = 0; group < force.getNumInteractionGroups(); group++) {
        // Get the list of atoms in this group and sort them.
        
        set<int> set1, set2;
        force.getInteractionGroupParameters(group, set1, set2);
        vector<int> atoms1, atoms2;
        atoms1.insert(atoms1.begin(), set1.begin(), set1.end());
        atoms2.insert(atoms2.begin(), set2.begin(), set2.end());
        sort(atoms1.begin(), atoms1.end());
        sort(atoms2.begin(), atoms2.end());
2785
2786
2787
2788
        duplicateAtomsForGroup.push_back(vector<int>());
        set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(),
                inserter(duplicateAtomsForGroup[group], duplicateAtomsForGroup[group].begin()));
        sort(duplicateAtomsForGroup[group].begin(), duplicateAtomsForGroup[group].end());
2789
2790
2791
2792
        
        // Find how many tiles we will create for this group.
        
        int tileWidth = min(min(32, (int) atoms1.size()), (int) atoms2.size());
2793
2794
        if (tileWidth == 0)
            continue;
2795
2796
2797
2798
2799
        int numBlocks1 = (atoms1.size()+tileWidth-1)/tileWidth;
        int numBlocks2 = (atoms2.size()+tileWidth-1)/tileWidth;
        
        // Add the tiles.
        
2800
        int firstTile = tiles.size();
2801
        for (int i = 0; i < numBlocks1; i++)
2802
            for (int j = 0; j < numBlocks2; j++) {
2803
                tiles.push_back(make_pair(atomLists.size()+i, atomLists.size()+numBlocks1+j));
2804
2805
                tileGroup.push_back(group);
            }
2806
2807
2808
2809
2810
2811
2812
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
2827
2828
        
        // Add the atom lists.
        
        for (int i = 0; i < numBlocks1; i++) {
            vector<int> atoms;
            int first = i*tileWidth;
            int last = min((i+1)*tileWidth, (int) atoms1.size());
            for (int j = first; j < last; j++)
                atoms.push_back(atoms1[j]);
            atomLists.push_back(atoms);
        }
        for (int i = 0; i < numBlocks2; i++) {
            vector<int> atoms;
            int first = i*tileWidth;
            int last = min((i+1)*tileWidth, (int) atoms2.size());
            for (int j = first; j < last; j++)
                atoms.push_back(atoms2[j]);
            atomLists.push_back(atoms);
        }
    }
    
    // Build a lookup table for quickly identifying excluded interactions.
    
2829
    vector<set<int> > exclusions(force.getNumParticles());
2830
2831
2832
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int p1, p2;
        force.getExclusionParticles(i, p1, p2);
2833
2834
        exclusions[p1].insert(p2);
        exclusions[p2].insert(p1);
2835
2836
2837
2838
2839
2840
2841
2842
    }
    
    // Build the exclusion flags for each tile.  While we're at it, filter out tiles
    // where all interactions are excluded, and sort the tiles by size.

    vector<vector<int> > exclusionFlags(tiles.size());
    vector<pair<int, int> > tileOrder;
    for (int tile = 0; tile < tiles.size(); tile++) {
2843
        bool swapped = false;
2844
2845
2846
2847
2848
2849
        if (atomLists[tiles[tile].first].size() < atomLists[tiles[tile].second].size()) {
            // For efficiency, we want the first axis to be the larger one.
            
            int swap = tiles[tile].first;
            tiles[tile].first = tiles[tile].second;
            tiles[tile].second = swap;
2850
            swapped = true;
2851
2852
2853
        }
        vector<int>& atoms1 = atomLists[tiles[tile].first];
        vector<int>& atoms2 = atomLists[tiles[tile].second];
2854
        vector<int>& duplicateAtoms = duplicateAtomsForGroup[tileGroup[tile]];
2855
2856
        vector<int>& flags = exclusionFlags[tile];
        flags.resize(atoms1.size(), (int) (1LL<<atoms2.size())-1);
2857
        int numExcluded = 0;
2858
2859
2860
        for (int i = 0; i < (int) atoms1.size(); i++) {
            int a1 = atoms1[i];
            bool a1IsDuplicate = binary_search(duplicateAtoms.begin(), duplicateAtoms.end(), a1);
2861
2862
            for (int j = 0; j < (int) atoms2.size(); j++) {
                int a2 = atoms2[j];
peastman's avatar
peastman committed
2863
                bool isExcluded = false;
2864
                if (a1 == a2 || exclusions[a1].find(a2) != exclusions[a1].end())
peastman's avatar
peastman committed
2865
                    isExcluded = true; // This is an excluded interaction.
2866
2867
                else if ((a1 > a2) == swapped && a1IsDuplicate && binary_search(duplicateAtoms.begin(), duplicateAtoms.end(), a2))
                    isExcluded = true; // Both atoms are in both sets, so skip duplicate interactions.
peastman's avatar
peastman committed
2868
                if (isExcluded) {
2869
2870
2871
2872
                    flags[i] &= -1-(1<<j);
                    numExcluded++;
                }
            }
2873
        }
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
        if (numExcluded == atoms1.size()*atoms2.size())
            continue; // All interactions are excluded.
        tileOrder.push_back(make_pair((int) -atoms2.size(), tile));
    }
    sort(tileOrder.begin(), tileOrder.end());
    
    // Merge tiles to get as close as possible to 32 along the first axis of each one.
    
    vector<int> tileSetStart;
    tileSetStart.push_back(0);
    int tileSetSize = 0;
    for (int i = 0; i < tileOrder.size(); i++) {
        int tile = tileOrder[i].second;
        int size = atomLists[tiles[tile].first].size();
        if (tileSetSize+size > 32) {
            tileSetStart.push_back(i);
            tileSetSize = 0;
        }
        tileSetSize += size;
    }
    tileSetStart.push_back(tileOrder.size());
    
    // Build the data structures.
    
    int numTileSets = tileSetStart.size()-1;
    vector<mm_int4> groupData;
    for (int tileSet = 0; tileSet < numTileSets; tileSet++) {
        int indexInTileSet = 0;
2902
2903
2904
2905
2906
2907
2908
2909
        int minSize = 0;
        if (cl.getSIMDWidth() < 32) {
            // We need to include a barrier inside the inner loop, so ensure that all
            // threads will loop the same number of times.
            
            for (int i = tileSetStart[tileSet]; i < tileSetStart[tileSet+1]; i++)
                minSize = max(minSize, (int) atomLists[tiles[tileOrder[i].second].first].size());
        }
2910
2911
2912
2913
        for (int i = tileSetStart[tileSet]; i < tileSetStart[tileSet+1]; i++) {
            int tile = tileOrder[i].second;
            vector<int>& atoms1 = atomLists[tiles[tile].first];
            vector<int>& atoms2 = atomLists[tiles[tile].second];
2914
            int range = indexInTileSet + ((indexInTileSet+max(minSize, (int) atoms1.size()))<<16);
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
            int allFlags = (1<<atoms2.size())-1;
            for (int j = 0; j < (int) atoms1.size(); j++) {
                int a1 = atoms1[j];
                int a2 = (j < atoms2.size() ? atoms2[j] : 0);
                int flags = (exclusionFlags[tile].size() > 0 ? exclusionFlags[tile][j] : allFlags);
                groupData.push_back(mm_int4(a1, a2, range, flags<<indexInTileSet));
            }
            indexInTileSet += atoms1.size();
        }
        for (; indexInTileSet < 32; indexInTileSet++)
2925
            groupData.push_back(mm_int4(0, 0, minSize<<16, 0));
2926
    }
peastman's avatar
peastman committed
2927
2928
    interactionGroupData.initialize<mm_int4>(cl, groupData.size(), "interactionGroupData");
    interactionGroupData.upload(groupData);
2929
2930
2931
2932
2933
2934
2935
2936
2937
2938
    numGroupTiles.initialize<cl_int>(cl, 1, "numGroupTiles");

    // Allocate space for a neighbor list, if necessary.

    if (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && groupData.size() > cl.getNumThreadBlocks()) {
        filteredGroupData.initialize<mm_int4>(cl, groupData.size(), "filteredGroupData");
        interactionGroupData.copyTo(filteredGroupData);
        int numTiles = groupData.size()/32;
        numGroupTiles.upload(&numTiles);
    }
2939
2940
2941
    
    // Create the kernel.
    
2942
    hasParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
2943
2944
2945
2946
2947
2948
2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
2960
2961
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = interactionSource;
    const string suffixes[] = {"x", "y", "z", "w"};
    stringstream localData;
    int localDataSize = 0;
    vector<OpenCLNonbondedUtilities::ParameterInfo>& buffers = params->getBuffers(); 
    for (int i = 0; i < (int) buffers.size(); i++) {
        if (buffers[i].getNumComponents() == 1)
            localData<<buffers[i].getComponentType()<<" params"<<(i+1)<<";\n";
        else {
            for (int j = 0; j < buffers[i].getNumComponents(); ++j)
                localData<<buffers[i].getComponentType()<<" params"<<(i+1)<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += buffers[i].getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
    stringstream args;
    for (int i = 0; i < (int) buffers.size(); i++)
        args<<", __global const "<<buffers[i].getType()<<"* restrict global_params"<<(i+1);
2962
2963
    for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
        args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
peastman's avatar
peastman committed
2964
    if (globals.isInitialized())
2965
        args<<", __global const float* restrict globals";
2966
2967
    if (hasParamDerivs)
        args << ", __global mixed* restrict energyParamDerivs";
2968
2969
2970
2971
2972
2973
2974
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    stringstream load1;
    for (int i = 0; i < (int) buffers.size(); i++)
        load1<<buffers[i].getType()<<" params"<<(i+1)<<"1 = global_params"<<(i+1)<<"[atom1];\n";
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    stringstream loadLocal2;
    for (int i = 0; i < (int) buffers.size(); i++) {
        if (buffers[i].getNumComponents() == 1)
            loadLocal2<<"localData[get_local_id(0)].params"<<(i+1)<<" = global_params"<<(i+1)<<"[atom2];\n";
        else {
            loadLocal2<<buffers[i].getType()<<" temp_params"<<(i+1)<<" = global_params"<<(i+1)<<"[atom2];\n";
            for (int j = 0; j < buffers[i].getNumComponents(); ++j)
                loadLocal2<<"localData[get_local_id(0)].params"<<(i+1)<<"_"<<suffixes[j]<<" = temp_params"<<(i+1)<<"."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS"] = loadLocal2.str();
    stringstream load2;
    for (int i = 0; i < (int) buffers.size(); i++) {
        if (buffers[i].getNumComponents() == 1)
            load2<<buffers[i].getType()<<" params"<<(i+1)<<"2 = localData[localIndex].params"<<(i+1)<<";\n";
        else {
2989
            load2<<buffers[i].getType()<<" params"<<(i+1)<<"2 = ("<<buffers[i].getType()<<") (";
2990
2991
2992
2993
2994
2995
2996
2997
2998
            for (int j = 0; j < buffers[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2<<", ";
                load2<<"localData[localIndex].params"<<(i+1)<<"_"<<suffixes[j];
            }
            load2<<");\n";
        }
    }
    replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
    stringstream initDerivs, saveDerivs;
    const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
    int numDerivs = allParamDerivNames.size();
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string paramName = force.getEnergyParameterDerivativeName(i);
        string derivVariable = cl.getNonbondedUtilities().addEnergyParameterDerivative(paramName);
        initDerivs<<"mixed "<<derivVariable<<" = 0;\n";
        for (int index = 0; index < numDerivs; index++)
            if (allParamDerivNames[index] == paramName)
                saveDerivs<<"energyParamDerivs[get_global_id(0)*"<<numDerivs<<"+"<<index<<"] += "<<derivVariable<<";\n";
    }
    replacements["INIT_DERIVATIVES"] = initDerivs.str();
    replacements["SAVE_DERIVATIVES"] = saveDerivs.str();
3012
3013
3014
3015
3016
    map<string, string> defines;
    if (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff)
        defines["USE_CUTOFF"] = "1";
    if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic)
        defines["USE_PERIODIC"] = "1";
3017
3018
3019
    int localMemorySize = max(32, cl.getNonbondedUtilities().getForceThreadBlockSize());
    defines["LOCAL_MEMORY_SIZE"] = cl.intToString(localMemorySize);
    defines["WARPS_IN_BLOCK"] = cl.intToString(localMemorySize/32);
3020
3021
    double cutoff = force.getCutoffDistance();
    defines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
3022
3023
    double paddedCutoff = cl.getNonbondedUtilities().padCutoff(cutoff);
    defines["PADDED_CUTOFF_SQUARED"] = cl.doubleToString(paddedCutoff*paddedCutoff);
3024
3025
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    defines["TILE_SIZE"] = "32";
3026
    defines["NUM_TILES"] = cl.intToString(numTileSets);
3027
3028
3029
3030
3031
3032
3033
3034
3035
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*numTileSets/numContexts;
    int endIndex = (cl.getContextIndex()+1)*numTileSets/numContexts;
    defines["FIRST_TILE"] = cl.intToString(startIndex);
    defines["LAST_TILE"] = cl.intToString(endIndex);
    if ((localDataSize/4)%2 == 0 && !cl.getUseDoublePrecision())
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customNonbondedGroups, replacements), defines);
    interactionGroupKernel = cl::Kernel(program, "computeInteractionGroups");
3036
3037
    prepareNeighborListKernel = cl::Kernel(program, "prepareToBuildNeighborList");
    buildNeighborListKernel = cl::Kernel(program, "buildNeighborList");
3038
3039
3040
    numGroupThreadBlocks = cl.getNonbondedUtilities().getNumForceThreadBlocks();
}

3041
/**
 * Execute the kernel for this force.
 *
 * Changed global parameters are uploaded, the long range correction is computed
 * when needed, and, if interaction groups are in use, the interaction group kernel
 * is launched (after rebuilding its neighbor list when one is used).
 *
 * @param context        the context in which to execute this kernel
 * @param includeForces  unused by this method
 * @param includeEnergy  unused by this method
 * @return the long range correction coefficient divided by the periodic box volume,
 *         or 0 on secondary contexts when a neighbor list is in use
 */
double OpenCLCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    useNeighborList = (filteredGroupData.isInitialized() && cl.getNonbondedUtilities().getUseCutoff());
    if (useNeighborList && cl.getContextIndex() > 0) {
        // When using a neighbor list, run the whole calculation on a single device.
        return 0.0;
    }
    if (globals.isInitialized()) {
        // Upload the global parameters only if at least one of them has changed.
        
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed) {
            globals.upload(globalParamValues);
            if (forceCopy != NULL) {
                // The correction is recomputed whenever a global parameter changes.
                
                CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs);
                hasInitializedLongRangeCorrection = true;
            }
        }
    }
    if (!hasInitializedLongRangeCorrection) {
        CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs);
        hasInitializedLongRangeCorrection = true;
    }
    if (interactionGroupData.isInitialized()) {
        if (!hasInitializedKernel) {
            // Set the kernel arguments that do not change between invocations.
            
            hasInitializedKernel = true;
            int index = 0;
            bool useLong = cl.getSupports64BitGlobalAtomics();
            interactionGroupKernel.setArg<cl::Buffer>(index++, (useLong ? cl.getLongForceBuffer() : cl.getForceBuffers()).getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, (useNeighborList ? filteredGroupData : interactionGroupData).getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, numGroupTiles.getDeviceBuffer());
            interactionGroupKernel.setArg<cl_int>(index++, useNeighborList);
            index += 5; // Skip the periodic box arguments, which are set on every call below.
            for (auto& buffer : params->getBuffers())
                interactionGroupKernel.setArg<cl::Memory>(index++, buffer.getMemory());
            for (auto& function : tabulatedFunctions)
                interactionGroupKernel.setArg<cl::Memory>(index++, function.getDeviceBuffer());
            if (globals.isInitialized())
                interactionGroupKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
            if (hasParamDerivs)
                interactionGroupKernel.setArg<cl::Memory>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
            if (useNeighborList) {
                // Initialize kernels for building the interaction group neighbor list.
                
                prepareNeighborListKernel.setArg<cl::Buffer>(0, cl.getNonbondedUtilities().getRebuildNeighborList().getDeviceBuffer());
                prepareNeighborListKernel.setArg<cl::Buffer>(1, numGroupTiles.getDeviceBuffer());
                buildNeighborListKernel.setArg<cl::Buffer>(0, cl.getNonbondedUtilities().getRebuildNeighborList().getDeviceBuffer());
                buildNeighborListKernel.setArg<cl::Buffer>(1, numGroupTiles.getDeviceBuffer());
                buildNeighborListKernel.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
                buildNeighborListKernel.setArg<cl::Buffer>(3, interactionGroupData.getDeviceBuffer());
                buildNeighborListKernel.setArg<cl::Buffer>(4, filteredGroupData.getDeviceBuffer());
            }
        }
        int forceThreadBlockSize = max(32, cl.getNonbondedUtilities().getForceThreadBlockSize());
        if (useNeighborList) {
            // Rebuild the neighbor list, if necessary.

            setPeriodicBoxArgs(cl, buildNeighborListKernel, 5);
            cl.executeKernel(prepareNeighborListKernel, 1, 1);
            cl.executeKernel(buildNeighborListKernel, numGroupThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
        }
        setPeriodicBoxArgs(cl, interactionGroupKernel, 6);
        cl.executeKernel(interactionGroupKernel, numGroupThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
    }
    
    // Add the long range correction and its parameter derivatives, scaled by 1/volume.
    
    mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
    double volume = boxSize.x*boxSize.y*boxSize.z;
    map<string, double>& derivs = cl.getEnergyParamDerivWorkspace();
    for (int i = 0; i < (int) longRangeCoefficientDerivs.size(); i++)
        derivs[forceCopy->getEnergyParameterDerivativeName(i)] += longRangeCoefficientDerivs[i]/volume;
    return longRangeCoefficient/volume;
}
Peter Eastman's avatar
Peter Eastman committed
3117

3118
3119
3120
3121
3122
3123
3124
3125
3126
3127
3128
3129
3130
3131
3132
3133
3134
/**
 * Copy changed per-particle parameters from a force into the context.
 *
 * @param context  the context to copy parameters to
 * @param force    the CustomNonbondedForce to copy the parameters from
 * @throws OpenMMException if the number of particles differs from the context's
 */
void OpenCLCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force) {
    int numParticles = force.getNumParticles();
    if (numParticles != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");
    
    // Record the per-particle parameters.
    
    vector<vector<cl_float> > paramVector(numParticles);
    vector<double> parameters;
    for (int i = 0; i < numParticles; i++) {
        force.getParticleParameters(i, parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
    
    // If necessary, recompute the long range correction.
    
    if (forceCopy != NULL) {
        CustomNonbondedForceImpl::calcLongRangeCorrection(force, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs);
        hasInitializedLongRangeCorrection = true;
        
        // Keep the stored copy in sync, since execute() recomputes the correction from it.
        
        *forceCopy = force;
    }
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules(info);
}

3148
/**
 * Force information for a GBSAOBCForce.  It reports which particles have identical
 * parameters.
 */
class OpenCLCalcGBSAOBCForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(int requiredBuffers, const GBSAOBCForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    /**
     * Two particles are identical if their charge, radius, and scaling factor all
     * compare exactly equal.
     */
    bool areParticlesIdentical(int particle1, int particle2) {
        double charge1, charge2, radius1, radius2, scale1, scale2;
        force.getParticleParameters(particle1, charge1, radius1, scale1);
        force.getParticleParameters(particle2, charge2, radius2, scale2);
        return (charge1 == charge2 && radius1 == radius2 && scale1 == scale2);
    }
private:
    // Held by reference; the force object must outlive this ForceInfo.
    const GBSAOBCForce& force;
};

3162
void OpenCLCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOBCForce& force) {
3163
3164
    if (cl.getPlatformData().contexts.size() > 1)
        throw OpenMMException("GBSAOBCForce does not support using multiple OpenCL devices");
3165
3166
3167
3168
    int forceIndex;
    for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
        ;
    string prefix = "obc"+cl.intToString(forceIndex)+"_";
3169
    OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
peastman's avatar
peastman committed
3170
    params.initialize<mm_float2>(cl, cl.getPaddedNumAtoms(), "gbsaObcParams");
3171
    int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
3172
    charges.initialize(cl, cl.getPaddedNumAtoms(), elementSize, "gbsaObcCharges");
peastman's avatar
peastman committed
3173
3174
    bornRadii.initialize(cl, cl.getPaddedNumAtoms(), elementSize, "bornRadii");
    obcChain.initialize(cl, cl.getPaddedNumAtoms(), elementSize, "obcChain");
3175
    if (cl.getSupports64BitGlobalAtomics()) {
peastman's avatar
peastman committed
3176
3177
3178
3179
3180
        longBornSum.initialize<cl_long>(cl, cl.getPaddedNumAtoms(), "longBornSum");
        longBornForce.initialize<cl_long>(cl, cl.getPaddedNumAtoms(), "longBornForce");
        bornForce.initialize(cl, cl.getPaddedNumAtoms(), elementSize, "bornForce");
        cl.addAutoclearBuffer(longBornSum);
        cl.addAutoclearBuffer(longBornForce);
3181
3182
    }
    else {
peastman's avatar
peastman committed
3183
3184
3185
3186
        bornSum.initialize(cl, cl.getPaddedNumAtoms()*nb.getNumForceBuffers(), elementSize, "bornSum");
        bornForce.initialize(cl, cl.getPaddedNumAtoms()*nb.getNumForceBuffers(), elementSize, "bornForce");
        cl.addAutoclearBuffer(bornSum);
        cl.addAutoclearBuffer(bornForce);
3187
    }
3188
    vector<double> chargeVec(cl.getPaddedNumAtoms());
3189
    vector<mm_float2> paramsVector(cl.getPaddedNumAtoms(), mm_float2(1,1));
3190
    const double dielectricOffset = 0.009;
3191
    for (int i = 0; i < force.getNumParticles(); i++) {
3192
3193
3194
        double charge, radius, scalingFactor;
        force.getParticleParameters(i, charge, radius, scalingFactor);
        radius -= dielectricOffset;
3195
        chargeVec[i] = charge;
3196
        paramsVector[i] = mm_float2((float) radius, (float) (scalingFactor*radius));
3197
    }
peastman's avatar
peastman committed
3198
    charges.upload(chargeVec, true, true);
peastman's avatar
peastman committed
3199
    params.upload(paramsVector);
3200
    prefactor = -ONE_4PI_EPS0*((1.0/force.getSoluteDielectric())-(1.0/force.getSolventDielectric()));
3201
    surfaceAreaFactor = -6.0*4*M_PI*force.getSurfaceAreaEnergy();
3202
3203
    bool useCutoff = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff && force.getNonbondedMethod() != GBSAOBCForce::CutoffNonPeriodic);
3204
    cutoff = force.getCutoffDistance();
3205
    string source = OpenCLKernelSources::gbsaObc2;
3206
3207
3208
3209
3210
3211
3212
3213
    map<string, string> replacements;
    replacements["CHARGE1"] = prefix+"charge1";
    replacements["CHARGE2"] = prefix+"charge2";
    replacements["OBC_PARAMS1"] = prefix+"obcParams1";
    replacements["OBC_PARAMS2"] = prefix+"obcParams2";
    replacements["BORN_FORCE1"] = prefix+"bornForce1";
    replacements["BORN_FORCE2"] = prefix+"bornForce2";
    source = cl.replaceStrings(source, replacements);
3214
    nb.addInteraction(useCutoff, usePeriodic, false, cutoff, vector<vector<int> >(), source, force.getForceGroup());
3215
3216
3217
    nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"charge", "float", 1, sizeof(cl_float), charges.getDeviceBuffer()));;
    nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"obcParams", "float", 2, sizeof(cl_float2), params.getDeviceBuffer()));;
    nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"bornForce", "real", 1, elementSize, bornForce.getDeviceBuffer()));;
3218
3219
    info = new ForceInfo(nb.getNumForceBuffers(), force);
    cl.addForce(info);
3220
3221
}

3222
/**
 * Execute the kernels for this force.
 *
 * On the first call the kernels are compiled and their fixed arguments are set.  Every
 * call then launches, in order, the Born sum kernel, its reduction, the first force
 * kernel, and the Born force reduction.
 *
 * @param context        the context in which to execute this kernel
 * @param includeForces  unused by this method
 * @param includeEnergy  passed to the force kernel as its "include energy" flag
 * @return always 0
 */
double OpenCLCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
    if (!hasCreatedKernels) {
        // These Kernels cannot be created in initialize(), because the OpenCLNonbondedUtilities has not been initialized yet then.

        hasCreatedKernels = true;
        maxTiles = (nb.getUseCutoff() ? nb.getInteractingTiles().getSize() : 0);
        map<string, string> defines;
        if (nb.getUseCutoff())
            defines["USE_CUTOFF"] = "1";
        if (nb.getUsePeriodic())
            defines["USE_PERIODIC"] = "1";
        defines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
        defines["CUTOFF"] = cl.doubleToString(cutoff);
        defines["PREFACTOR"] = cl.doubleToString(prefactor);
        defines["SURFACE_AREA_FACTOR"] = cl.doubleToString(surfaceAreaFactor);
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
        defines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
        defines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(nb.getForceThreadBlockSize());
        defines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
        int numExclusionTiles = nb.getExclusionTiles().getSize();
        defines["NUM_TILES_WITH_EXCLUSIONS"] = cl.intToString(numExclusionTiles);
        int numContexts = cl.getPlatformData().contexts.size();
        int startExclusionIndex = cl.getContextIndex()*numExclusionTiles/numContexts;
        int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts;
        defines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex);
        defines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex);
        string platformVendor = cl::Platform(cl.getDevice().getInfo<CL_DEVICE_PLATFORM>()).getInfo<CL_PLATFORM_VENDOR>();
        if (platformVendor == "Apple")
            defines["USE_APPLE_WORKAROUND"] = "1";
        
        // CPU devices use a different kernel source.  The device type only matters
        // here, so it is queried once rather than on every call.
        
        bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
        string file;
        if (deviceIsCpu)
            file = OpenCLKernelSources::gbsaObc_cpu;
        else
            file = OpenCLKernelSources::gbsaObc;
        cl::Program program = cl.createProgram(file, defines);
        bool useLong = cl.getSupports64BitGlobalAtomics();
        
        // Set the arguments of the Born sum kernel.
        
        int index = 0;
        computeBornSumKernel = cl::Kernel(program, "computeBornSum");
        computeBornSumKernel.setArg<cl::Buffer>(index++, (useLong ? longBornSum.getDeviceBuffer() : bornSum.getDeviceBuffer()));
        computeBornSumKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        computeBornSumKernel.setArg<cl::Buffer>(index++, charges.getDeviceBuffer());
        computeBornSumKernel.setArg<cl::Buffer>(index++, params.getDeviceBuffer());
        if (nb.getUseCutoff()) {
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
            index += 5; // The periodic box size arguments are set when the kernel is executed.
            computeBornSumKernel.setArg<cl_uint>(index++, maxTiles);
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getBlockCenters().getDeviceBuffer());
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else
            computeBornSumKernel.setArg<cl_uint>(index++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getExclusionTiles().getDeviceBuffer());
        
        // Set the arguments of the first force kernel.
        
        force1Kernel = cl::Kernel(program, "computeGBSAForce1");
        index = 0;
        force1Kernel.setArg<cl::Buffer>(index++, (useLong ? cl.getLongForceBuffer().getDeviceBuffer() : cl.getForceBuffers().getDeviceBuffer()));
        force1Kernel.setArg<cl::Buffer>(index++, (useLong ? longBornForce.getDeviceBuffer() : bornForce.getDeviceBuffer()));
        force1Kernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        force1Kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        force1Kernel.setArg<cl::Buffer>(index++, charges.getDeviceBuffer());
        force1Kernel.setArg<cl::Buffer>(index++, bornRadii.getDeviceBuffer());
        index++; // Whether to include energy.
        if (nb.getUseCutoff()) {
            force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
            index += 5; // The periodic box size arguments are set when the kernel is executed.
            force1Kernel.setArg<cl_uint>(index++, maxTiles);
            force1Kernel.setArg<cl::Buffer>(index++, nb.getBlockCenters().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(index++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else
            force1Kernel.setArg<cl_uint>(index++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        force1Kernel.setArg<cl::Buffer>(index++, nb.getExclusionTiles().getDeviceBuffer());
        
        // Set the arguments of the reduction kernels.
        
        program = cl.createProgram(OpenCLKernelSources::gbsaObcReductions, defines);
        reduceBornSumKernel = cl::Kernel(program, "reduceBornSum");
        reduceBornSumKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
        reduceBornSumKernel.setArg<cl_int>(1, nb.getNumForceBuffers());
        reduceBornSumKernel.setArg<cl_float>(2, 1.0f);
        reduceBornSumKernel.setArg<cl_float>(3, 0.8f);
        reduceBornSumKernel.setArg<cl_float>(4, 4.85f);
        reduceBornSumKernel.setArg<cl::Buffer>(5, (useLong ? longBornSum.getDeviceBuffer() : bornSum.getDeviceBuffer()));
        reduceBornSumKernel.setArg<cl::Buffer>(6, params.getDeviceBuffer());
        reduceBornSumKernel.setArg<cl::Buffer>(7, bornRadii.getDeviceBuffer());
        reduceBornSumKernel.setArg<cl::Buffer>(8, obcChain.getDeviceBuffer());
        reduceBornForceKernel = cl::Kernel(program, "reduceBornForce");
        index = 0;
        reduceBornForceKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
        reduceBornForceKernel.setArg<cl_int>(index++, nb.getNumForceBuffers());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, bornForce.getDeviceBuffer());
        if (useLong)
            reduceBornForceKernel.setArg<cl::Buffer>(index++, longBornForce.getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, params.getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, bornRadii.getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, obcChain.getDeviceBuffer());
    }
    
    // Argument 6 of the force kernel is the slot skipped above.
    
    force1Kernel.setArg<cl_int>(6, includeEnergy);
    if (nb.getUseCutoff()) {
        setPeriodicBoxArgs(cl, computeBornSumKernel, 6);
        setPeriodicBoxArgs(cl, force1Kernel, 9);
        if (maxTiles < nb.getInteractingTiles().getSize()) {
            // The interacting tiles buffer is larger than when the arguments were last
            // set, so set the arguments that depend on it again.
            
            maxTiles = nb.getInteractingTiles().getSize();
            computeBornSumKernel.setArg<cl::Buffer>(4, nb.getInteractingTiles().getDeviceBuffer());
            computeBornSumKernel.setArg<cl_uint>(11, maxTiles);
            computeBornSumKernel.setArg<cl::Buffer>(14, nb.getInteractingAtoms().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(7, nb.getInteractingTiles().getDeviceBuffer());
            force1Kernel.setArg<cl_uint>(14, maxTiles);
            force1Kernel.setArg<cl::Buffer>(17, nb.getInteractingAtoms().getDeviceBuffer());
        }
    }
    cl.executeKernel(computeBornSumKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(reduceBornSumKernel, cl.getPaddedNumAtoms());
    cl.executeKernel(force1Kernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(reduceBornForceKernel, cl.getPaddedNumAtoms());
    return 0.0;
}
3343

3344
3345
3346
3347
3348
3349
3350
3351
3352
/**
 * Re-read the per-particle parameters from a GBSAOBCForce and upload them to the device.
 *
 * @param context  the context the parameters are being copied into (not used in this method)
 * @param force    the GBSAOBCForce to read charge, radius and scaling factor from
 * @throws OpenMMException if the force's particle count differs from the context's atom count
 */
void OpenCLCalcGBSAOBCForceKernel::copyParametersToContext(ContextImpl& context, const GBSAOBCForce& force) {
    // The force must still describe exactly as many particles as the context has atoms.

    const int particleCount = force.getNumParticles();
    if (particleCount != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Build host-side arrays sized to the padded atom count.  Padding entries keep
    // a charge of 0 and a (1,1) parameter pair; real particles overwrite the front.

    const int paddedCount = cl.getPaddedNumAtoms();
    const double dielectricOffset = 0.009;
    vector<double> chargeVector(paddedCount, 0.0);
    vector<mm_float2> paramsVector(paddedCount, mm_float2(1,1));
    for (int atom = 0; atom < particleCount; atom++) {
        double charge, radius, scalingFactor;
        force.getParticleParameters(atom, charge, radius, scalingFactor);
        chargeVector[atom] = charge;

        // Store the offset radius together with the scaled offset radius.
        const double offsetRadius = radius-dielectricOffset;
        paramsVector[atom] = mm_float2((float) offsetRadius, (float) (scalingFactor*offsetRadius));
    }

    // Send both arrays to the device.

    charges.upload(chargeVector, true, true);
    params.upload(paramsVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

3373
/**
 * Describes the particles and particle groups of a CustomGBForce.
 * Each exclusion of the force is reported as one two-particle group.
 */
class OpenCLCalcCustomGBForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(int requiredBuffers, const CustomGBForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    /**
     * Two particles are identical when every per-particle parameter value matches.
     */
    bool areParticlesIdentical(int particle1, int particle2) {
        vector<double> first;
        vector<double> second;
        force.getParticleParameters(particle1, first);
        force.getParticleParameters(particle2, second);
        const int count = (int) first.size();
        int matched = 0;
        while (matched < count && first[matched] == second[matched])
            matched++;
        return (matched == count);
    }
    /**
     * One group per exclusion.
     */
    int getNumParticleGroups() {
        return force.getNumExclusions();
    }
    /**
     * Fill particles with the two particles of the exclusion at the given index.
     */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int firstParticle, secondParticle;
        force.getExclusionParticles(index, firstParticle, secondParticle);
        particles.clear();
        particles.push_back(firstParticle);
        particles.push_back(secondParticle);
    }
    /**
     * Every pair of groups is reported as identical.
     */
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomGBForce& force;
};

/**
 * Release the parameter sets this kernel deletes on destruction.
 */
OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit null checks are needed.
    delete params;
    delete computedValues;
    delete energyDerivs;
    delete energyDerivChain;
    for (int i = 0; i < (int) dValuedParam.size(); i++)
        delete dValuedParam[i];
}

void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) {
3418
3419
    if (cl.getPlatformData().contexts.size() > 1)
        throw OpenMMException("CustomGBForce does not support using multiple OpenCL devices");
3420
    cutoff = force.getCutoffDistance();
3421
    bool useExclusionsForValue = false;
3422
    numComputedValues = force.getNumComputedValues();
3423
3424
    vector<string> computedValueNames(force.getNumComputedValues());
    vector<string> computedValueExpressions(force.getNumComputedValues());
3425
3426
    if (force.getNumComputedValues() > 0) {
        CustomGBForce::ComputationType type;
3427
        force.getComputedValueParameters(0, computedValueNames[0], computedValueExpressions[0], type);
3428
3429
3430
3431
        if (type == CustomGBForce::SingleParticle)
            throw OpenMMException("OpenCLPlatform requires that the first computed value for a CustomGBForce be of type ParticlePair or ParticlePairNoExclusions.");
        useExclusionsForValue = (type == CustomGBForce::ParticlePair);
        for (int i = 1; i < force.getNumComputedValues(); i++) {
3432
            force.getComputedValueParameters(i, computedValueNames[i], computedValueExpressions[i], type);
3433
3434
3435
3436
3437
3438
3439
            if (type != CustomGBForce::SingleParticle)
                throw OpenMMException("OpenCLPlatform requires that a CustomGBForce only have one computed value of type ParticlePair or ParticlePairNoExclusions.");
        }
    }
    int forceIndex;
    for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
        ;
3440
    string prefix = "custom"+cl.intToString(forceIndex)+"_";
3441
3442
3443
3444

    // Record parameters and exclusions.

    int numParticles = force.getNumParticles();
3445
3446
3447
3448
    int paddedNumParticles = cl.getPaddedNumAtoms();
    int numParams = force.getNumPerParticleParameters();
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), paddedNumParticles, "customGBParameters", true);
    computedValues = new OpenCLParameterSet(cl, force.getNumComputedValues(), paddedNumParticles, "customGBComputedValues", true, cl.getUseDoublePrecision());
3449
    if (force.getNumGlobalParameters() > 0)
peastman's avatar
peastman committed
3450
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customGBGlobals", CL_MEM_READ_ONLY);
3451
    vector<vector<cl_float> > paramVector(paddedNumParticles, vector<cl_float>(numParams, 0));
3452
3453
3454
3455
    vector<vector<int> > exclusionList(numParticles);
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
        force.getParticleParameters(i, parameters);
3456
        for (int j = 0; j < (int) parameters.size(); j++)
3457
3458
3459
3460
3461
3462
3463
3464
3465
3466
3467
3468
3469
3470
3471
            paramVector[i][j] = (cl_float) parameters[j];
        exclusionList[i].push_back(i);
    }
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int particle1, particle2;
        force.getExclusionParticles(i, particle1, particle2);
        exclusionList[particle1].push_back(particle2);
        exclusionList[particle2].push_back(particle1);
    }
    params->setParameterValues(paramVector);

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
3472
    vector<const TabulatedFunction*> functionList;
3473
    stringstream tableArgs;
peastman's avatar
peastman committed
3474
    tabulatedFunctions.resize(force.getNumTabulatedFunctions());
3475
3476
3477
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
3478
        string arrayName = prefix+"table"+cl.intToString(i);
3479
        functionDefinitions.push_back(make_pair(name, arrayName));
3480
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
3481
        int width;
3482
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
3483
3484
3485
        tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
        tabulatedFunctions[i].upload(f);
        cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[i].getDeviceBuffer()));
3486
3487
3488
3489
        tableArgs << ", __global const float";
        if (width > 1)
            tableArgs << width;
        tableArgs << "* restrict " << arrayName;
3490
3491
    }

3492
    // Record the global parameters.
3493
3494
3495
3496
3497
3498
3499

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
peastman's avatar
peastman committed
3500
3501
    if (globals.isInitialized())
        globals.upload(globalParamValues);
3502
3503
3504

    // Record derivatives of expressions needed for the chain rule terms.

3505
    vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(force.getNumComputedValues());
3506
    vector<vector<Lepton::ParsedExpression> > valueDerivExpressions(force.getNumComputedValues());
3507
    vector<vector<Lepton::ParsedExpression> > valueParamDerivExpressions(force.getNumComputedValues());
Peter Eastman's avatar
Peter Eastman committed
3508
    needParameterGradient = false;
3509
    for (int i = 0; i < force.getNumComputedValues(); i++) {
3510
        Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
3511
3512
3513
3514
3515
3516
3517
3518
3519
3520
3521
        if (i > 0) {
            valueGradientExpressions[i].push_back(ex.differentiate("x").optimize());
            valueGradientExpressions[i].push_back(ex.differentiate("y").optimize());
            valueGradientExpressions[i].push_back(ex.differentiate("z").optimize());
            if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2]))
                needParameterGradient = true;
             for (int j = 0; j < i; j++)
                valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
        }
        for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
            valueParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
3522
    }
3523
    vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
3524
    vector<vector<Lepton::ParsedExpression> > energyParamDerivExpressions(force.getNumEnergyTerms());
Peter Eastman's avatar
Peter Eastman committed
3525
    vector<bool> needChainForValue(force.getNumComputedValues(), false);
3526
3527
3528
3529
3530
3531
    for (int i = 0; i < force.getNumEnergyTerms(); i++) {
        string expression;
        CustomGBForce::ComputationType type;
        force.getEnergyTermParameters(i, expression, type);
        Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize();
        for (int j = 0; j < force.getNumComputedValues(); j++) {
Peter Eastman's avatar
Peter Eastman committed
3532
            if (type == CustomGBForce::SingleParticle) {
3533
                energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
Peter Eastman's avatar
Peter Eastman committed
3534
3535
3536
                if (!isZeroExpression(energyDerivExpressions[i].back()))
                    needChainForValue[j] = true;
            }
3537
3538
            else {
                energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"1").optimize());
Peter Eastman's avatar
Peter Eastman committed
3539
3540
                if (!isZeroExpression(energyDerivExpressions[i].back()))
                    needChainForValue[j] = true;
3541
                energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"2").optimize());
Peter Eastman's avatar
Peter Eastman committed
3542
3543
                if (!isZeroExpression(energyDerivExpressions[i].back()))
                    needChainForValue[j] = true;
3544
3545
            }
        }
3546
3547
        for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
            energyParamDerivExpressions[i].push_back(ex.differentiate(force.getEnergyParameterDerivativeName(j)).optimize());
3548
    }
3549
    bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
3550
    bool useLong = cl.getSupports64BitGlobalAtomics();
3551
    if (useLong) {
peastman's avatar
peastman committed
3552
        longEnergyDerivs.initialize<cl_long>(cl, force.getNumComputedValues()*cl.getPaddedNumAtoms(), "customGBLongEnergyDerivatives");
Peter Eastman's avatar
Peter Eastman committed
3553
        energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivatives", true);
3554
3555
    }
    else
Peter Eastman's avatar
Peter Eastman committed
3556
        energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), "customGBEnergyDerivatives", true);
3557
    energyDerivChain = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivativeChain", true);
3558
3559
    int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
    needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
peastman's avatar
peastman committed
3560
    dValue0dParam.resize(force.getNumEnergyParameterDerivatives());
3561
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
3562
        dValuedParam.push_back(new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "dValuedParam", true, cl.getUseDoublePrecision()));
3563
        if (useLong)
peastman's avatar
peastman committed
3564
            dValue0dParam[i].initialize<cl_long>(cl, cl.getPaddedNumAtoms(), "dValue0dParam");
3565
        else
peastman's avatar
peastman committed
3566
3567
            dValue0dParam[i].initialize(cl, cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), elementSize, "dValue0dParam");
        cl.addAutoclearBuffer(dValue0dParam[i]);
3568
3569
3570
        string name = force.getEnergyParameterDerivativeName(i);
        cl.addEnergyParameterDerivative(name);
    }
3571

3572
3573
    // Create the kernels.

3574
3575
    bool useCutoff = (force.getNonbondedMethod() != CustomGBForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != CustomGBForce::NoCutoff && force.getNonbondedMethod() != CustomGBForce::CutoffNonPeriodic);
3576
3577
3578
    {
        // Create the N2 value kernel.

3579
        vector<pair<ExpressionTreeNode, string> > variables;
3580
        map<string, string> rename;
3581
3582
3583
3584
        ExpressionTreeNode rnode(new Operation::Variable("r"));
        variables.push_back(make_pair(rnode, "r"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
3585
3586
        for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
            const string& name = force.getPerParticleParameterName(i);
3587
3588
            variables.push_back(makeVariable(name+"1", "params"+params->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(name+"2", "params"+params->getParameterSuffix(i, "2")));
3589
3590
            rename[name+"1"] = name+"2";
            rename[name+"2"] = name+"1";
3591
3592
3593
        }
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
3594
            string value = "globals["+cl.intToString(i)+"]";
3595
            variables.push_back(makeVariable(name, value));
3596
        }
3597
3598
        map<string, Lepton::ParsedExpression> n2ValueExpressions;
        stringstream n2ValueSource;
3599
3600
3601
        Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
        n2ValueExpressions["tempValue1 = "] = ex;
        n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
3602
3603
3604
3605
3606
3607
3608
        for (int i = 0; i < valueParamDerivExpressions[0].size(); i++) {
            string variableBase = "temp_dValue0dParam"+cl.intToString(i+1);
            if (!isZeroExpression(valueParamDerivExpressions[0][i])) {
                n2ValueExpressions[variableBase+"_1 = "] = valueParamDerivExpressions[0][i];
                n2ValueExpressions[variableBase+"_2 = "] = valueParamDerivExpressions[0][i].renameVariables(rename);
            }
        }
3609
        n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp");
3610
        map<string, string> replacements;
Peter Eastman's avatar
Peter Eastman committed
3611
3612
        string n2ValueStr = n2ValueSource.str();
        replacements["COMPUTE_VALUE"] = n2ValueStr;
3613
        stringstream extraArgs, loadLocal1, loadLocal2, load1, load2, tempDerivs1, tempDerivs2, storeDeriv1, storeDeriv2;
3614
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3615
            extraArgs << ", __global const float* globals";
Peter Eastman's avatar
Peter Eastman committed
3616
        pairValueUsesParam.resize(params->getBuffers().size(), false);
3617
3618
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3619
            string paramName = "params"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3620
3621
3622
3623
3624
3625
3626
3627
            if (n2ValueStr.find(paramName+"1") != n2ValueStr.npos || n2ValueStr.find(paramName+"2") != n2ValueStr.npos) {
                extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName;
                loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
                loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
                load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
                load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
                pairValueUsesParam[i] = true;
            }
3628
        }
3629
3630
3631
3632
3633
3634
3635
3636
3637
3638
3639
3640
3641
        for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
            string derivName = "dValue0dParam"+cl.intToString(i+1);
            if (useLong)
                extraArgs << ", __global long* restrict global_" << derivName;
            else
                extraArgs << ", __global real* restrict global_" << derivName;
            extraArgs << ", __local real* restrict local_" << derivName;
            loadLocal2 << "local_" << derivName << "[localAtomIndex] = 0;\n";
            load1 << "real " << derivName << " = 0;\n";
            if (!isZeroExpression(valueParamDerivExpressions[0][i])) {
                load2 << "real temp_" << derivName << "_1 = 0;\n";
                load2 << "real temp_" << derivName << "_2 = 0;\n";
                tempDerivs1 << derivName << " += temp_" << derivName << "_1;\n";
peastman's avatar
peastman committed
3642
3643
3644
3645
                if (deviceIsCpu)
                    tempDerivs2 << "local_" << derivName << "[j] += temp_" << derivName << "_2;\n";
                else
                    tempDerivs2 << "local_" << derivName << "[tbx+tj] += temp_" << derivName << "_2;\n";
3646
3647
                if (useLong) {
                    storeDeriv1 << "atom_add(&global_" << derivName << "[offset1], (long) (" << derivName << "*0x100000000));\n";
peastman's avatar
peastman committed
3648
3649
3650
3651
                    if (deviceIsCpu)
                        storeDeriv2 << "atom_add(&global_" << derivName << "[offset2], (long) (local_" << derivName << "[tgx]*0x100000000));\n";
                    else
                        storeDeriv2 << "atom_add(&global_" << derivName << "[offset2], (long) (local_" << derivName << "[get_local_id(0)]*0x100000000));\n";
3652
3653
3654
                }
                else {
                    storeDeriv1 << "global_" << derivName << "[offset1] += " << derivName << ";\n";
peastman's avatar
peastman committed
3655
3656
3657
3658
                    if (deviceIsCpu)
                        storeDeriv2 << "global_" << derivName << "[offset2] += local_" << derivName << "[tgx];\n";
                    else
                        storeDeriv2 << "global_" << derivName << "[offset2] += local_" << derivName << "[get_local_id(0)];\n";
3659
3660
3661
                }
            }
        }
3662
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
3663
3664
3665
3666
        replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
        replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
        replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
        replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
3667
3668
3669
3670
        replacements["ADD_TEMP_DERIVS1"] = tempDerivs1.str();
        replacements["ADD_TEMP_DERIVS2"] = tempDerivs2.str();
        replacements["STORE_PARAM_DERIVS1"] = storeDeriv1.str();
        replacements["STORE_PARAM_DERIVS2"] = storeDeriv2.str();
3671
        if (useCutoff)
3672
            pairValueDefines["USE_CUTOFF"] = "1";
3673
        if (usePeriodic)
3674
            pairValueDefines["USE_PERIODIC"] = "1";
3675
        if (useExclusionsForValue)
3676
3677
            pairValueDefines["USE_EXCLUSIONS"] = "1";
        pairValueDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize());
3678
        pairValueDefines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
3679
3680
3681
3682
        pairValueDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        pairValueDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
        pairValueDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
        pairValueDefines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
3683
3684
3685
3686
        string file;
        if (deviceIsCpu)
            file = OpenCLKernelSources::customGBValueN2_cpu;
        else
3687
3688
            file = OpenCLKernelSources::customGBValueN2;
        pairValueSrc = cl.replaceStrings(file, replacements);
3689
3690
        if (useExclusionsForValue)
            cl.getNonbondedUtilities().requestExclusions(exclusionList);
3691
3692
3693
3694
    }
    {
        // Create the kernel to reduce the N2 value and calculate other values.

3695
        stringstream reductionSource, extraArgs, deriv0;
3696
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3697
            extraArgs << ", __global const float* globals";
3698
3699
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3700
            string paramName = "params"+cl.intToString(i+1);
3701
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << paramName;
3702
3703
3704
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3705
            string valueName = "values"+cl.intToString(i+1);
3706
            extraArgs << ", __global " << buffer.getType() << "* restrict global_" << valueName;
3707
3708
            reductionSource << buffer.getType() << " local_" << valueName << ";\n";
        }
3709
3710
3711
3712
        for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
            string variableName = "dValuedParam_0_"+cl.intToString(i);
            if (useLong) {
                extraArgs << ", __global const long* restrict dValue0dParam" << i;
3713
                deriv0 << "real " << variableName << " = (1.0f/0x100000000)*dValue0dParam" << i << "[index];\n";
3714
3715
3716
3717
3718
3719
3720
3721
3722
3723
3724
            }
            else {
                extraArgs << ", __global const real* restrict dValue0dParam" << i;
                deriv0 << "real " << variableName << " = dValue0dParam" << i << "[index];\n";
                deriv0 << "for (int i = index+bufferSize; i < totalSize; i += bufferSize)\n";
                deriv0 << "    " << variableName << " += dValue0dParam" << i << "[i];\n";
            }
            for (int j = 0; j < dValuedParam[i]->getBuffers().size(); j++)
                extraArgs << ", __global real* restrict global_dValuedParam_" << j << "_" << i;
            deriv0 << "global_dValuedParam_0_" << i << "[index] = dValuedParam_0_" << i << ";\n";
        }
3725
        reductionSource << "local_values" << computedValues->getParameterSuffix(0) << " = sum;\n";
3726
        map<string, string> variables;
3727
3728
3729
        variables["x"] = "pos.x";
        variables["y"] = "pos.y";
        variables["z"] = "pos.z";
3730
3731
3732
        for (int i = 0; i < force.getNumPerParticleParameters(); i++)
            variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3733
            variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
3734
3735
3736
3737
        for (int i = 1; i < force.getNumComputedValues(); i++) {
            variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
            map<string, Lepton::ParsedExpression> valueExpressions;
            valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
3738
            reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cl.intToString(i)+"_temp");
3739
        }
3740
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
3741
            string valueName = "values"+cl.intToString(i+1);
3742
3743
            reductionSource << "global_" << valueName << "[index] = local_" << valueName << ";\n";
        }
3744
3745
3746
3747
3748
3749
3750
3751
3752
3753
3754
3755
3756
3757
3758
3759
3760
        if (needEnergyParamDerivs) {
            map<string, Lepton::ParsedExpression> derivExpressions;
            for (int i = 1; i < force.getNumComputedValues(); i++) {
                for (int j = 0; j < valueParamDerivExpressions[i].size(); j++)
                    derivExpressions["real dValuedParam_"+cl.intToString(i)+"_"+cl.intToString(j)+" = "] = valueParamDerivExpressions[i][j];
                for (int j = 0; j < i; j++)
                    derivExpressions["real dVdV_"+cl.intToString(i)+"_"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
            }
            reductionSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "derivChain_temp");
            for (int i = 1; i < force.getNumComputedValues(); i++) {
                for (int j = 0; j < i; j++)
                    for (int k = 0; k < valueParamDerivExpressions[i].size(); k++)
                        reductionSource << "dValuedParam_" << i << "_" << k << " += dVdV_" << i << "_" << j << "*dValuedParam_" << j <<"_" << k << ";\n";
                for (int j = 0; j < valueParamDerivExpressions[i].size(); j++)
                    reductionSource << "global_dValuedParam_" << i << "_" << j << "[index] = dValuedParam_" << i << "_" << j << ";\n";
            }
        }
3761
        map<string, string> replacements;
3762
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
3763
        replacements["REDUCE_PARAM0_DERIV"] = deriv0.str();
3764
3765
        replacements["COMPUTE_VALUES"] = reductionSource.str();
        map<string, string> defines;
3766
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
3767
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBValuePerParticle, replacements), defines);
3768
3769
3770
3771
3772
        perParticleValueKernel = cl::Kernel(program, "computePerParticleValues");
    }
    {
        // Create the N2 energy kernel.

3773
3774
3775
3776
3777
        vector<pair<ExpressionTreeNode, string> > variables;
        ExpressionTreeNode rnode(new Operation::Variable("r"));
        variables.push_back(make_pair(rnode, "r"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
3778
3779
        for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
            const string& name = force.getPerParticleParameterName(i);
3780
3781
            variables.push_back(makeVariable(name+"1", "params"+params->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(name+"2", "params"+params->getParameterSuffix(i, "2")));
3782
3783
        }
        for (int i = 0; i < force.getNumComputedValues(); i++) {
3784
3785
            variables.push_back(makeVariable(computedValueNames[i]+"1", "values"+computedValues->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(computedValueNames[i]+"2", "values"+computedValues->getParameterSuffix(i, "2")));
3786
3787
        }
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3788
            variables.push_back(makeVariable(force.getGlobalParameterName(i), "globals["+cl.intToString(i)+"]"));
3789
        stringstream n2EnergySource;
3790
        bool anyExclusions = (force.getNumExclusions() > 0);
3791
3792
3793
3794
3795
3796
        for (int i = 0; i < force.getNumEnergyTerms(); i++) {
            string expression;
            CustomGBForce::ComputationType type;
            force.getEnergyTermParameters(i, expression, type);
            if (type == CustomGBForce::SingleParticle)
                continue;
3797
            bool exclude = (anyExclusions && type == CustomGBForce::ParticlePair);
3798
            map<string, Lepton::ParsedExpression> n2EnergyExpressions;
3799
3800
            n2EnergyExpressions["tempEnergy += "] = Lepton::Parser::parse(expression, functions).optimize();
            n2EnergyExpressions["dEdR += "] = Lepton::Parser::parse(expression, functions).differentiate("r").optimize();
3801
3802
            if (useLong) {
                for (int j = 0; j < force.getNumComputedValues(); j++) {
Peter Eastman's avatar
Peter Eastman committed
3803
                    if (needChainForValue[j]) {
3804
3805
3806
                        string index = cl.intToString(j+1);
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+index+"_1 += "] = energyDerivExpressions[i][2*j];
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+index+"_2 += "] = energyDerivExpressions[i][2*j+1];
Peter Eastman's avatar
Peter Eastman committed
3807
                    }
3808
3809
3810
3811
                }
            }
            else {
                for (int j = 0; j < force.getNumComputedValues(); j++) {
Peter Eastman's avatar
Peter Eastman committed
3812
                    if (needChainForValue[j]) {
3813
3814
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_1")+" += "] = energyDerivExpressions[i][2*j];
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_2")+" += "] = energyDerivExpressions[i][2*j+1];
Peter Eastman's avatar
Peter Eastman committed
3815
                    }
3816
                }
3817
            }
3818
3819
            for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
                n2EnergyExpressions["energyParamDeriv"+cl.intToString(j)+" += interactionScale*"] = energyParamDerivExpressions[i][j];
3820
3821
            if (exclude)
                n2EnergySource << "if (!isExcluded) {\n";
3822
            n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp");
3823
3824
            if (exclude)
                n2EnergySource << "}\n";
3825
3826
        }
        map<string, string> replacements;
Peter Eastman's avatar
Peter Eastman committed
3827
3828
        string n2EnergyStr = n2EnergySource.str();
        replacements["COMPUTE_INTERACTION"] = n2EnergyStr;
3829
        stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps, initParamDerivs, saveParamDerivs;
3830
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3831
            extraArgs << ", __global const float* globals";
Peter Eastman's avatar
Peter Eastman committed
3832
        pairEnergyUsesParam.resize(params->getBuffers().size(), false);
3833
3834
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3835
            string paramName = "params"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3836
3837
3838
3839
3840
3841
3842
3843
            if (n2EnergyStr.find(paramName+"1") != n2EnergyStr.npos || n2EnergyStr.find(paramName+"2") != n2EnergyStr.npos) {
                extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName;
                loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
                loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
                load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
                load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
                pairEnergyUsesParam[i] = true;
            }
3844
        }
Peter Eastman's avatar
Peter Eastman committed
3845
        pairEnergyUsesValue.resize(computedValues->getBuffers().size(), false);
3846
3847
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3848
            string valueName = "values"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3849
3850
3851
3852
3853
3854
3855
3856
            if (n2EnergyStr.find(valueName+"1") != n2EnergyStr.npos || n2EnergyStr.find(valueName+"2") != n2EnergyStr.npos) {
                extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << valueName << ", __local " << buffer.getType() << "* restrict local_" << valueName;
                loadLocal1 << "local_" << valueName << "[localAtomIndex] = " << valueName << "1;\n";
                loadLocal2 << "local_" << valueName << "[localAtomIndex] = global_" << valueName << "[j];\n";
                load1 << buffer.getType() << " " << valueName << "1 = global_" << valueName << "[atom1];\n";
                load2 << buffer.getType() << " " << valueName << "2 = local_" << valueName << "[atom2];\n";
                pairEnergyUsesValue[i] = true;
            }
3857
        }
3858
        if (useLong) {
3859
            extraArgs << ", __global long* restrict derivBuffers";
3860
            for (int i = 0; i < force.getNumComputedValues(); i++) {
3861
                string index = cl.intToString(i+1);
3862
                extraArgs << ", __local real* restrict local_deriv" << index;
3863
                clearLocal << "local_deriv" << index << "[localAtomIndex] = 0.0f;\n";
3864
3865
                declare1 << "real deriv" << index << "_1 = 0;\n";
                load2 << "real deriv" << index << "_2 = 0;\n";
3866
3867
3868
                recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n";
                storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")\n";
                storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")\n";
3869
                declareTemps << "__local real tempDerivBuffer" << index << "[64];\n";
3870
3871
3872
3873
3874
3875
                setTemps << "tempDerivBuffer" << index << "[get_local_id(0)] = deriv" << index << "_1;\n";
            }
        }
        else {
            for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
3876
                string index = cl.intToString(i+1);
3877
                extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index << ", __local " << buffer.getType() << "* restrict local_deriv" << index;
3878
3879
3880
3881
3882
3883
3884
3885
3886
                clearLocal << "local_deriv" << index << "[localAtomIndex] = 0.0f;\n";
                declare1 << buffer.getType() << " deriv" << index << "_1 = 0.0f;\n";
                load2 << buffer.getType() << " deriv" << index << "_2 = 0.0f;\n";
                recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n";
                storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")\n";
                storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")\n";
                declareTemps << "__local " << buffer.getType() << " tempDerivBuffer" << index << "[64];\n";
                setTemps << "tempDerivBuffer" << index << "[get_local_id(0)] = deriv" << index << "_1;\n";
            }
3887
        }
3888
3889
3890
3891
3892
3893
3894
3895
3896
3897
3898
        if (needEnergyParamDerivs) {
            extraArgs << ", __global mixed* restrict energyParamDerivs";
            const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
            int numDerivs = allParamDerivNames.size();
            for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
                initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
                for (int index = 0; index < numDerivs; index++)
                    if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
                        saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
            }
        }
3899
3900
3901
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
        replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
        replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
3902
        replacements["CLEAR_LOCAL_DERIVATIVES"] = clearLocal.str();
3903
3904
        replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
        replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
3905
        replacements["DECLARE_ATOM1_DERIVATIVES"] = declare1.str();
3906
3907
3908
        replacements["RECORD_DERIVATIVE_2"] = recordDeriv.str();
        replacements["STORE_DERIVATIVES_1"] = storeDerivs1.str();
        replacements["STORE_DERIVATIVES_2"] = storeDerivs2.str();
3909
3910
        replacements["DECLARE_TEMP_BUFFERS"] = declareTemps.str();
        replacements["SET_TEMP_BUFFERS"] = setTemps.str();
3911
3912
        replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
        replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
3913
        if (useCutoff)
3914
            pairEnergyDefines["USE_CUTOFF"] = "1";
3915
        if (usePeriodic)
3916
            pairEnergyDefines["USE_PERIODIC"] = "1";
3917
        if (anyExclusions)
3918
3919
            pairEnergyDefines["USE_EXCLUSIONS"] = "1";
        pairEnergyDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize());
3920
        pairEnergyDefines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
3921
3922
3923
3924
        pairEnergyDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        pairEnergyDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
        pairEnergyDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
        pairEnergyDefines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
3925
3926
3927
3928
        string file;
        if (deviceIsCpu)
            file = OpenCLKernelSources::customGBEnergyN2_cpu;
        else
3929
3930
            file = OpenCLKernelSources::customGBEnergyN2;
        pairEnergySrc = cl.replaceStrings(file, replacements);
3931
3932
3933
3934
    }
    {
        // Create the kernel to reduce the derivatives and calculate per-particle energy terms.

3935
        stringstream compute, extraArgs, reduce, initParamDerivs, saveParamDerivs;
3936
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3937
            extraArgs << ", __global const float* globals";
3938
3939
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3940
            string paramName = "params"+cl.intToString(i+1);
3941
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << paramName;
3942
3943
3944
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3945
            string valueName = "values"+cl.intToString(i+1);
3946
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << valueName;
3947
        }
3948
3949
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
3950
            string index = cl.intToString(i+1);
3951
            extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index;
3952
3953
            compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
        }
3954
3955
3956
3957
3958
        for (int i = 0; i < (int) energyDerivChain->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivChain->getBuffers()[i];
            string index = cl.intToString(i+1);
            extraArgs << ", __global " << buffer.getType() << "* restrict derivChain" << index;
        }
3959
        if (useLong) {
3960
            extraArgs << ", __global const long* restrict derivBuffersIn";
3961
3962
            for (int i = 0; i < energyDerivs->getNumParameters(); ++i)
                reduce << "derivBuffers" << energyDerivs->getParameterSuffix(i, "[index]") <<
3963
                        " = (1.0f/0x100000000)*derivBuffersIn[index+PADDED_NUM_ATOMS*" << cl.intToString(i) << "];\n";
3964
3965
3966
        }
        else {
            for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++)
3967
                reduce << "REDUCE_VALUE(derivBuffers" << cl.intToString(i+1) << ", " << energyDerivs->getBuffers()[i].getType() << ")\n";
3968
        }
3969
3970
3971
3972
3973
3974
3975
3976
3977
3978
3979
        if (needEnergyParamDerivs) {
            extraArgs << ", __global mixed* restrict energyParamDerivs";
            const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
            int numDerivs = allParamDerivNames.size();
            for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
                initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
                for (int index = 0; index < numDerivs; index++)
                    if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
                        saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
            }
        }
Peter Eastman's avatar
Peter Eastman committed
3980
3981
3982
        
        // Compute the various expressions.
        
3983
        map<string, string> variables;
3984
3985
3986
        variables["x"] = "pos.x";
        variables["y"] = "pos.y";
        variables["z"] = "pos.z";
3987
3988
3989
        for (int i = 0; i < force.getNumPerParticleParameters(); i++)
            variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3990
            variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
3991
3992
        for (int i = 0; i < force.getNumComputedValues(); i++)
            variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
Peter Eastman's avatar
Peter Eastman committed
3993
        map<string, Lepton::ParsedExpression> expressions;
3994
3995
3996
3997
3998
3999
        for (int i = 0; i < force.getNumEnergyTerms(); i++) {
            string expression;
            CustomGBForce::ComputationType type;
            force.getEnergyTermParameters(i, expression, type);
            if (type != CustomGBForce::SingleParticle)
                continue;
4000
            Lepton::ParsedExpression parsed = Lepton::Parser::parse(expression, functions).optimize();
4001
            expressions["/*"+cl.intToString(i+1)+"*/ energy += "] = parsed;
4002
            for (int j = 0; j < force.getNumComputedValues(); j++)
4003
                expressions["/*"+cl.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j)+" += "] = energyDerivExpressions[i][j];
4004
4005
4006
4007
            Lepton::ParsedExpression gradx = parsed.differentiate("x").optimize();
            Lepton::ParsedExpression grady = parsed.differentiate("y").optimize();
            Lepton::ParsedExpression gradz = parsed.differentiate("z").optimize();
            if (!isZeroExpression(gradx))
4008
                expressions["/*"+cl.intToString(i+1)+"*/ force.x -= "] = gradx;
4009
            if (!isZeroExpression(grady))
4010
                expressions["/*"+cl.intToString(i+1)+"*/ force.y -= "] = grady;
4011
            if (!isZeroExpression(gradz))
4012
                expressions["/*"+cl.intToString(i+1)+"*/ force.z -= "] = gradz;
4013
4014
            for (int j = 0; j < force.getNumEnergyParameterDerivatives(); j++)
                expressions["/*"+cl.intToString(i+1)+"*/ energyParamDeriv"+cl.intToString(j)+" += "] = energyParamDerivExpressions[i][j];
Peter Eastman's avatar
Peter Eastman committed
4015
4016
4017
        }
        for (int i = 1; i < force.getNumComputedValues(); i++)
            for (int j = 0; j < i; j++)
4018
                expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
4019
        compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp");
Peter Eastman's avatar
Peter Eastman committed
4020
4021
4022
        
        // Record values.
        
4023
4024
4025
4026
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            string index = cl.intToString(i+1);
            compute << "derivBuffers" << index << "[index] = deriv" << index << ";\n";
        }
Peter Eastman's avatar
Peter Eastman committed
4027
4028
        compute << "forceBuffers[index] = forceBuffers[index]+force;\n";
        for (int i = 1; i < force.getNumComputedValues(); i++) {
4029
            compute << "real totalDeriv"<<i<<" = dV"<<i<<"dV0";
Peter Eastman's avatar
Peter Eastman committed
4030
4031
4032
4033
            for (int j = 1; j < i; j++)
                compute << " + totalDeriv"<<j<<"*dV"<<i<<"dV"<<j;
            compute << ";\n";
            compute << "deriv"<<(i+1)<<" *= totalDeriv"<<i<<";\n";
4034
4035
        }
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
4036
            string index = cl.intToString(i+1);
4037
            compute << "derivChain" << index << "[index] = deriv" << index << ";\n";
4038
4039
4040
        }
        map<string, string> replacements;
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
4041
4042
        replacements["REDUCE_DERIVATIVES"] = reduce.str();
        replacements["COMPUTE_ENERGY"] = compute.str();
4043
4044
        replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
        replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
4045
        map<string, string> defines;
4046
4047
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
4048
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBEnergyPerParticle, replacements), defines);
4049
        perParticleEnergyKernel = cl::Kernel(program, "computePerParticleEnergy");
4050
    }
4051
4052
4053
    if (needParameterGradient || needEnergyParamDerivs) {
        // Create the kernel to compute chain rule terms for computed values that depend explicitly on particle coordinates, and for
        // derivatives with respect to global parameters.
Peter Eastman's avatar
Peter Eastman committed
4054

4055
        stringstream compute, extraArgs, initParamDerivs, saveParamDerivs;
Peter Eastman's avatar
Peter Eastman committed
4056
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
4057
            extraArgs << ", __global const float* globals";
Peter Eastman's avatar
Peter Eastman committed
4058
4059
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
4060
            string paramName = "params"+cl.intToString(i+1);
4061
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << paramName;
Peter Eastman's avatar
Peter Eastman committed
4062
4063
4064
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
4065
            string valueName = "values"+cl.intToString(i+1);
4066
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << valueName;
Peter Eastman's avatar
Peter Eastman committed
4067
4068
4069
        }
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
4070
            string index = cl.intToString(i+1);
4071
            extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index;
Peter Eastman's avatar
Peter Eastman committed
4072
4073
            compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
        }
4074
4075
4076
4077
4078
4079
4080
4081
4082
4083
4084
4085
4086
        if (needEnergyParamDerivs) {
            extraArgs << ", __global mixed* restrict energyParamDerivs";
            const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
            int numDerivs = allParamDerivNames.size();
            for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
                for (int j = 0; j < dValuedParam[i]->getBuffers().size(); j++)
                    extraArgs << ", __global real* restrict dValuedParam_" << j << "_" << i;
                initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
                for (int index = 0; index < numDerivs; index++)
                    if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
                        saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
            }
        }
Peter Eastman's avatar
Peter Eastman committed
4087
4088
4089
4090
4091
4092
4093
        map<string, string> variables;
        variables["x"] = "pos.x";
        variables["y"] = "pos.y";
        variables["z"] = "pos.z";
        for (int i = 0; i < force.getNumPerParticleParameters(); i++)
            variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
4094
            variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
Peter Eastman's avatar
Peter Eastman committed
4095
4096
        for (int i = 0; i < force.getNumComputedValues(); i++)
            variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
4097
4098
4099
4100
4101
4102
4103
4104
4105
4106
4107
4108
        if (needParameterGradient) {
            for (int i = 1; i < force.getNumComputedValues(); i++) {
                string is = cl.intToString(i);
                compute << "real4 dV"<<is<<"dR = (real4) 0;\n";
                for (int j = 1; j < i; j++) {
                    if (!isZeroExpression(valueDerivExpressions[i][j])) {
                        map<string, Lepton::ParsedExpression> derivExpressions;
                        string js = cl.intToString(j);
                        derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
                        compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js);
                        compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
                    }
4109
                }
4110
4111
4112
4113
4114
4115
4116
4117
                map<string, Lepton::ParsedExpression> gradientExpressions;
                if (!isZeroExpression(valueGradientExpressions[i][0]))
                    gradientExpressions["dV"+is+"dR.x += "] = valueGradientExpressions[i][0];
                if (!isZeroExpression(valueGradientExpressions[i][1]))
                    gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
                if (!isZeroExpression(valueGradientExpressions[i][2]))
                    gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
                compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp");
4118
            }
4119
4120
            for (int i = 1; i < force.getNumComputedValues(); i++)
                compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<i<<"dR;\n";
Peter Eastman's avatar
Peter Eastman committed
4121
        }
4122
4123
4124
4125
        if (needEnergyParamDerivs)
            for (int i = 0; i < force.getNumComputedValues(); i++)
                for (int j = 0; j < dValuedParam.size(); j++)
                    compute << "energyParamDeriv"<<j<<" += deriv"<<energyDerivs->getParameterSuffix(i)<<"*dValuedParam_"<<i<<"_"<<j<<"[index];\n";
Peter Eastman's avatar
Peter Eastman committed
4126
4127
4128
        map<string, string> replacements;
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
        replacements["COMPUTE_FORCES"] = compute.str();
4129
4130
        replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
        replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
Peter Eastman's avatar
Peter Eastman committed
4131
        map<string, string> defines;
4132
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
Peter Eastman's avatar
Peter Eastman committed
4133
4134
4135
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBGradientChainRule, replacements), defines);
        gradientChainRuleKernel = cl::Kernel(program, "computeGradientChainRuleTerms");
    }
4136
    {
peastman's avatar
peastman committed
4137
        // Create the code to calculate chain rule terms as part of the default nonbonded kernel.
4138

4139
        vector<pair<ExpressionTreeNode, string> > globalVariables;
4140
4141
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
4142
            string value = "globals["+cl.intToString(i)+"]";
4143
            globalVariables.push_back(makeVariable(name, prefix+value));
4144
        }
4145
        vector<pair<ExpressionTreeNode, string> > variables = globalVariables;
4146
        map<string, string> rename;
4147
4148
4149
4150
        ExpressionTreeNode rnode(new Operation::Variable("r"));
        variables.push_back(make_pair(rnode, "r"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
4151
4152
        for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
            const string& name = force.getPerParticleParameterName(i);
4153
4154
            variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2")));
Peter Eastman's avatar
Peter Eastman committed
4155
4156
            rename[name+"1"] = name+"2";
            rename[name+"2"] = name+"1";
4157
4158
4159
4160
        }
        map<string, Lepton::ParsedExpression> derivExpressions;
        stringstream chainSource;
        Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
4161
4162
        derivExpressions["real dV0dR1 = "] = dVdR;
        derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
4163
        chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_");
Peter Eastman's avatar
Peter Eastman committed
4164
4165
4166
4167
4168
4169
4170
        if (needChainForValue[0]) {
            if (useExclusionsForValue)
                chainSource << "if (!isExcluded) {\n";
            chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
            chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
            if (useExclusionsForValue)
                chainSource << "}\n";
4171
        }
Peter Eastman's avatar
Peter Eastman committed
4172
4173
4174
4175
        for (int i = 1; i < force.getNumComputedValues(); i++) {
            if (needChainForValue[i]) {
                chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
                chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
4176
            }
4177
4178
        }
        map<string, string> replacements;
Peter Eastman's avatar
Peter Eastman committed
4179
4180
        string chainStr = chainSource.str();
        replacements["COMPUTE_FORCE"] = chainStr;
4181
        string source = cl.replaceStrings(OpenCLKernelSources::customGBChainRule, replacements);
4182
4183
        vector<OpenCLNonbondedUtilities::ParameterInfo> parameters;
        vector<OpenCLNonbondedUtilities::ParameterInfo> arguments;
4184
4185
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
4186
            string paramName = prefix+"params"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
4187
4188
            if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
                parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
4189
4190
4191
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
4192
            string paramName = prefix+"values"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
4193
4194
            if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
                parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
4195
        }
4196
        for (int i = 0; i < (int) energyDerivChain->getBuffers().size(); i++) {
Peter Eastman's avatar
Peter Eastman committed
4197
            if (needChainForValue[i]) { 
4198
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivChain->getBuffers()[i];
4199
                string paramName = prefix+"dEdV"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
4200
4201
                parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
            }
4202
        }
peastman's avatar
peastman committed
4203
4204
4205
        if (globals.isInitialized()) {
            globals.upload(globalParamValues);
            arguments.push_back(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", 1, sizeof(cl_float), globals.getDeviceBuffer()));
4206
        }
4207
        cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, cutoff, exclusionList, source, force.getForceGroup());
peastman's avatar
peastman committed
4208
4209
4210
4211
        for (auto param : parameters)
            cl.getNonbondedUtilities().addParameter(param);
        for (auto arg : arguments)
            cl.getNonbondedUtilities().addArgument(arg);
4212
    }
4213
4214
    info = new ForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force);
    cl.addForce(info);
4215
    if (useLong)
peastman's avatar
peastman committed
4216
        cl.addAutoclearBuffer(longEnergyDerivs);
Peter Eastman's avatar
Peter Eastman committed
4217
    else {
peastman's avatar
peastman committed
4218
        for (auto& buffer : energyDerivs->getBuffers())
4219
            cl.addAutoclearBuffer(buffer.getMemory(), buffer.getSize()*energyDerivs->getNumObjects());
Peter Eastman's avatar
Peter Eastman committed
4220
    }
4221
4222
}

4223
/**
 * Queue the CustomGBForce kernels.
 *
 * On the first call this compiles the two pairwise (N2) kernels, creates the buffer the
 * pairwise values are accumulated into, and binds every kernel argument that does not
 * change between calls.  On every call it re-reads the global parameters (uploading them
 * only if one changed), binds the arguments that can change from call to call, and then
 * executes the kernels in order.
 *
 * @param context        the context from which the current global parameter values are read
 * @param includeForces  not used by this method
 * @param includeEnergy  forwarded to the pair energy kernel as its "include energy" argument
 * @return always 0.0; the kernels are handed cl.getEnergyBuffer() as an argument
 */
double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
    int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // Number of entries in every size-only (NULL pointer) kernel argument set below:
        // the tile size on CPU devices, the force thread block size otherwise.

        const auto localEntries = (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize());

        // These two kernels can't be compiled in initialize(), because the nonbonded utilities object
        // has not yet been initialized then.

        // Both kernels use the same split of the exclusion tiles between contexts.

        int numExclusionTiles = nb.getExclusionTiles().getSize();
        int numContexts = cl.getPlatformData().contexts.size();
        int startExclusionIndex = cl.getContextIndex()*numExclusionTiles/numContexts;
        int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts;
        {
            pairValueDefines["NUM_TILES_WITH_EXCLUSIONS"] = cl.intToString(numExclusionTiles);
            pairValueDefines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex);
            pairValueDefines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex);
            pairValueDefines["CUTOFF"] = cl.doubleToString(cutoff);
            cl::Program program = cl.createProgram(pairValueSrc, pairValueDefines);
            pairValueKernel = cl::Kernel(program, "computeN2Value");

            // The source and defines are only needed for this one compilation.

            pairValueSrc = "";
            pairValueDefines.clear();
        }
        {
            pairEnergyDefines["NUM_TILES_WITH_EXCLUSIONS"] = cl.intToString(numExclusionTiles);
            pairEnergyDefines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex);
            pairEnergyDefines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex);
            pairEnergyDefines["CUTOFF"] = cl.doubleToString(cutoff);
            cl::Program program = cl.createProgram(pairEnergySrc, pairEnergyDefines);
            pairEnergyKernel = cl::Kernel(program, "computeN2Energy");
            pairEnergySrc = "";
            pairEnergyDefines.clear();
        }

        // Set arguments for kernels.

        maxTiles = (nb.getUseCutoff() ? nb.getInteractingTiles().getSize() : 0);
        bool useLong = cl.getSupports64BitGlobalAtomics();
        if (useLong) {
            longValueBuffers.initialize<cl_long>(cl, cl.getPaddedNumAtoms(), "customGBLongValueBuffers");
            cl.addAutoclearBuffer(longValueBuffers);
            cl.clearBuffer(longValueBuffers);
        }
        else {
            valueBuffers.initialize(cl, cl.getPaddedNumAtoms()*nb.getNumForceBuffers(), elementSize, "customGBValueBuffers");
            cl.addAutoclearBuffer(valueBuffers);
            cl.clearBuffer(valueBuffers);
        }

        // Arguments for the pair value kernel.  The fixed indices used further down in this
        // method (6, 8, 13, 16) depend on the order of the calls in this section.

        int index = 0;
        pairValueKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        pairValueKernel.setArg(index++, localEntries*4*elementSize, NULL);
        pairValueKernel.setArg<cl::Buffer>(index++, nb.getExclusions().getDeviceBuffer());
        pairValueKernel.setArg<cl::Buffer>(index++, nb.getExclusionTiles().getDeviceBuffer());
        pairValueKernel.setArg<cl::Buffer>(index++, useLong ? longValueBuffers.getDeviceBuffer() : valueBuffers.getDeviceBuffer());
        pairValueKernel.setArg(index++, localEntries*elementSize, NULL);
        if (nb.getUseCutoff()) {
            pairValueKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
            pairValueKernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
            index += 5; // Periodic box size arguments are set when the kernel is executed.
            pairValueKernel.setArg<cl_uint>(index++, maxTiles);
            pairValueKernel.setArg<cl::Buffer>(index++, nb.getBlockCenters().getDeviceBuffer());
            pairValueKernel.setArg<cl::Buffer>(index++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            pairValueKernel.setArg<cl::Buffer>(index++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else
            pairValueKernel.setArg<cl_uint>(index++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        if (globals.isInitialized())
            pairValueKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            if (pairValueUsesParam[i]) {
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
                pairValueKernel.setArg<cl::Memory>(index++, buffer.getMemory());
                pairValueKernel.setArg(index++, localEntries*buffer.getSize(), NULL);
            }
        }
        for (auto& d : dValue0dParam) {
            pairValueKernel.setArg<cl::Buffer>(index++, d.getDeviceBuffer());
            pairValueKernel.setArg(index++, localEntries*d.getElementSize(), NULL);
        }
        for (auto& function : tabulatedFunctions)
            pairValueKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());

        // Arguments for the per-particle value kernel.

        index = 0;
        perParticleValueKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
        perParticleValueKernel.setArg<cl_int>(index++, nb.getNumForceBuffers());
        perParticleValueKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        perParticleValueKernel.setArg<cl::Buffer>(index++, useLong ? longValueBuffers.getDeviceBuffer() : valueBuffers.getDeviceBuffer());
        if (globals.isInitialized())
            perParticleValueKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (auto& buffer : params->getBuffers())
            perParticleValueKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& buffer : computedValues->getBuffers())
            perParticleValueKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (int i = 0; i < (int) dValuedParam.size(); i++) {
            perParticleValueKernel.setArg<cl::Memory>(index++, dValue0dParam[i].getDeviceBuffer());
            for (int j = 0; j < (int) dValuedParam[i]->getBuffers().size(); j++)
                perParticleValueKernel.setArg<cl::Memory>(index++, dValuedParam[i]->getBuffers()[j].getMemory());
        }
        for (auto& function : tabulatedFunctions)
            perParticleValueKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());

        // Arguments for the pair energy kernel.  The fixed indices used further down in this
        // method (7, 8, 10, 15, 18) depend on the order of the calls in this section.

        index = 0;
        pairEnergyKernel.setArg<cl::Buffer>(index++, useLong ? cl.getLongForceBuffer().getDeviceBuffer() : cl.getForceBuffers().getDeviceBuffer());
        pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        pairEnergyKernel.setArg(index++, localEntries*4*elementSize, NULL);
        pairEnergyKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        pairEnergyKernel.setArg(index++, localEntries*4*elementSize, NULL);
        pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getExclusions().getDeviceBuffer());
        pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getExclusionTiles().getDeviceBuffer());
        index++; // Whether to include energy.
        if (nb.getUseCutoff()) {
            pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
            index += 5; // Periodic box size arguments are set when the kernel is executed.
            pairEnergyKernel.setArg<cl_uint>(index++, maxTiles);
            pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getBlockCenters().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(index++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else
            pairEnergyKernel.setArg<cl_uint>(index++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        if (globals.isInitialized())
            pairEnergyKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            if (pairEnergyUsesParam[i]) {
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
                pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
                pairEnergyKernel.setArg(index++, localEntries*buffer.getSize(), NULL);
            }
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            if (pairEnergyUsesValue[i]) {
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
                pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
                pairEnergyKernel.setArg(index++, localEntries*buffer.getSize(), NULL);
            }
        }
        if (useLong) {
            pairEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs.getDeviceBuffer());
            for (int i = 0; i < numComputedValues; ++i)
                pairEnergyKernel.setArg(index++, localEntries*elementSize, NULL);
        }
        else {
            for (auto& buffer : energyDerivs->getBuffers()) {
                pairEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
                pairEnergyKernel.setArg(index++, localEntries*buffer.getSize(), NULL);
            }
        }
        if (needEnergyParamDerivs)
            pairEnergyKernel.setArg<cl::Memory>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
        for (auto& function : tabulatedFunctions)
            pairEnergyKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());

        // Arguments for the per-particle energy kernel.

        index = 0;
        perParticleEnergyKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
        perParticleEnergyKernel.setArg<cl_int>(index++, nb.getNumForceBuffers());
        perParticleEnergyKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
        perParticleEnergyKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        perParticleEnergyKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        if (globals.isInitialized())
            perParticleEnergyKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (auto& buffer : params->getBuffers())
            perParticleEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& buffer : computedValues->getBuffers())
            perParticleEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& buffer : energyDerivs->getBuffers())
            perParticleEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& buffer : energyDerivChain->getBuffers())
            perParticleEnergyKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        if (useLong)
            perParticleEnergyKernel.setArg<cl::Memory>(index++, longEnergyDerivs.getDeviceBuffer());
        if (needEnergyParamDerivs)
            perParticleEnergyKernel.setArg<cl::Memory>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
        for (auto& function : tabulatedFunctions)
            perParticleEnergyKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());

        // Arguments for the gradient chain rule kernel, which is only executed when
        // needParameterGradient or needEnergyParamDerivs is set.

        if (needParameterGradient || needEnergyParamDerivs) {
            index = 0;
            gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
            gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            if (globals.isInitialized())
                gradientChainRuleKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
            for (auto& buffer : params->getBuffers())
                gradientChainRuleKernel.setArg<cl::Memory>(index++, buffer.getMemory());
            for (auto& buffer : computedValues->getBuffers())
                gradientChainRuleKernel.setArg<cl::Memory>(index++, buffer.getMemory());
            for (auto& buffer : energyDerivs->getBuffers())
                gradientChainRuleKernel.setArg<cl::Memory>(index++, buffer.getMemory());
            if (needEnergyParamDerivs) {
                gradientChainRuleKernel.setArg<cl::Buffer>(index++, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
                for (auto d : dValuedParam)
                    for (auto& buffer : d->getBuffers())
                        gradientChainRuleKernel.setArg<cl::Memory>(index++, buffer.getMemory());
            }
            for (auto& function : tabulatedFunctions)
                gradientChainRuleKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());
        }
    }

    // Re-read the global parameters and upload them only if at least one value changed.

    if (globals.isInitialized()) {
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals.upload(globalParamValues);
    }

    // Argument 7 of the pair energy kernel is the slot skipped during setup
    // ("Whether to include energy").

    pairEnergyKernel.setArg<cl_int>(7, includeEnergy);
    if (nb.getUseCutoff()) {
        // 8 and 10 are the first of the five periodic box slots skipped during setup.

        setPeriodicBoxArgs(cl, pairValueKernel, 8);
        setPeriodicBoxArgs(cl, pairEnergyKernel, 10);
        if (maxTiles < nb.getInteractingTiles().getSize()) {
            // The interacting tiles array is now larger than the size recorded in maxTiles,
            // so rebind the interacting tiles buffer, the tile limit and the interacting
            // atoms buffer on both pairwise kernels.

            maxTiles = nb.getInteractingTiles().getSize();
            pairValueKernel.setArg<cl::Buffer>(6, nb.getInteractingTiles().getDeviceBuffer());
            pairValueKernel.setArg<cl_uint>(13, maxTiles);
            pairValueKernel.setArg<cl::Buffer>(16, nb.getInteractingAtoms().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(8, nb.getInteractingTiles().getDeviceBuffer());
            pairEnergyKernel.setArg<cl_uint>(15, maxTiles);
            pairEnergyKernel.setArg<cl::Buffer>(18, nb.getInteractingAtoms().getDeviceBuffer());
        }
    }
    cl.executeKernel(pairValueKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(perParticleValueKernel, cl.getPaddedNumAtoms());
    cl.executeKernel(pairEnergyKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(perParticleEnergyKernel, cl.getPaddedNumAtoms());
    if (needParameterGradient || needEnergyParamDerivs)
        cl.executeKernel(gradientChainRuleKernel, cl.getPaddedNumAtoms());
    return 0.0;
}

4456
4457
4458
4459
4460
4461
4462
/**
 * Copy the per-particle parameters of a CustomGBForce into this kernel's parameter set.
 *
 * @param context  not used by this method
 * @param force    the force to read per-particle parameters from
 * @throws OpenMMException if force.getNumParticles() differs from cl.getNumAtoms()
 */
void OpenCLCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, const CustomGBForce& force) {
    int numParticles = force.getNumParticles();
    if (numParticles != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the per-particle parameters.  The table has one row per padded atom; rows past
    // numParticles keep their initial value of 0.

    vector<vector<cl_float> > paramVector(cl.getPaddedNumAtoms(), vector<cl_float>(force.getNumPerParticleParameters(), 0));
    vector<double> parameters;
    for (int i = 0; i < numParticles; i++) {
        force.getParticleParameters(i, parameters);
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

4477
/**
 * Describes, for a CustomExternalForce, which particles have identical parameters.
 */
class OpenCLCalcCustomExternalForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    /**
     * @param force         the force being described; held by reference
     * @param numParticles  size of the particle -> force entry lookup table
     */
    ForceInfo(const CustomExternalForce& force, int numParticles) : OpenCLForceInfo(0), force(force), indices(numParticles, -1) {
        // Record, for every particle the force acts on, the index of its entry in the force.
        // Particles the force does not mention keep the value -1.
        vector<double> params;
        for (int i = 0; i < force.getNumParticles(); i++) {
            int particle;
            force.getParticleParameters(i, particle, params);
            indices[particle] = i;
        }
    }
    /**
     * Two particles are identical if neither has an entry in the force, or if both have
     * entries whose parameter vectors compare equal.
     */
    bool areParticlesIdentical(int particle1, int particle2) {
        // Translate particle indices into indices of entries in the force.
        particle1 = indices[particle1];
        particle2 = indices[particle2];
        if (particle1 == -1 && particle2 == -1)
            return true;
        if (particle1 == -1 || particle2 == -1)
            return false;
        int temp;
        vector<double> params1;
        vector<double> params2;
        force.getParticleParameters(particle1, temp, params1);
        force.getParticleParameters(particle2, temp, params2);
        return (params1 == params2);
    }
private:
    const CustomExternalForce& force;
    // indices[p] is the index of particle p's entry in the force, or -1 if it has none.
    vector<int> indices;
};

// Releases the parameter set owned by this kernel, if one was created.
OpenCLCalcCustomExternalForceKernel::~OpenCLCalcCustomExternalForceKernel() {
    // Deleting a null pointer has no effect, so no explicit check is needed.
    delete params;
}

void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const CustomExternalForce& force) {
4515
4516
4517
4518
4519
4520
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumParticles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumParticles()/numContexts;
    numParticles = endIndex-startIndex;
    if (numParticles == 0)
        return;
4521
    vector<vector<int> > atoms(numParticles, vector<int>(1));
4522
4523
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customExternalParams");
    vector<vector<cl_float> > paramVector(numParticles);
4524
4525
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
4526
        force.getParticleParameters(startIndex+i, atoms[i][0], parameters);
4527
        paramVector[i].resize(parameters.size());
4528
        for (int j = 0; j < (int) parameters.size(); j++)
4529
            paramVector[i][j] = (cl_float) parameters[j];
4530
    }
4531
    params->setParameterValues(paramVector);
4532
4533
    info = new ForceInfo(force, system.getNumParticles());
    cl.addForce(info);
4534
4535
4536
4537
4538
4539
4540
4541
4542

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
4543
4544
4545
    map<string, Lepton::CustomFunction*> customFunctions;
    customFunctions["periodicdistance"] = cl.getExpressionUtilities().getPeriodicDistancePlaceholder();
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
4546
4547
4548
4549
4550
    Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
    Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
    Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
4551
4552
4553
    expressions["real dEdX = "] = forceExpressionX;
    expressions["real dEdY = "] = forceExpressionY;
    expressions["real dEdZ = "] = forceExpressionZ;
4554
4555
4556
4557

    // Create the kernels.

    map<string, string> variables;
4558
4559
4560
    variables["x"] = "pos1.x";
    variables["y"] = "pos1.y";
    variables["z"] = "pos1.z";
4561
4562
    for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
        const string& name = force.getPerParticleParameterName(i);
4563
        variables[name] = "particleParams"+params->getParameterSuffix(i);
4564
    }
4565
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
4566
4567
4568
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customExternalGlobals", CL_MEM_READ_ONLY);
        globals.upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals.getDeviceBuffer(), "float");
4569
4570
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
4571
            string value = argName+"["+cl.intToString(i)+"]";
4572
4573
            variables[name] = value;
        }
4574
4575
    }
    stringstream compute;
4576
4577
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
4578
4579
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
4580
    }
peastman's avatar
peastman committed
4581
4582
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
4583
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
4584
    map<string, string> replacements;
4585
    replacements["COMPUTE_FORCE"] = compute.str();
4586
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup());
4587
4588
}

4589
/**
 * Refresh the global parameter values used by the CustomExternalForce interaction.
 *
 * No kernel is launched here; the method only re-reads the global parameters from the
 * context and uploads them when at least one value has changed.
 *
 * @param context        the context from which the current global parameter values are read
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcCustomExternalForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (globals.isInitialized()) {
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals.upload(globalParamValues);
    }
    return 0.0;
}
4603

4604
4605
4606
4607
4608
4609
/**
 * Copy the per-particle parameters of a CustomExternalForce into this kernel's parameter set.
 *
 * Only the slice of the force's entries assigned to this context is read.
 *
 * @param context  not used by this method
 * @param force    the force to read per-particle parameters from
 * @throws OpenMMException if the size of this context's slice differs from numParticles
 */
void OpenCLCalcCustomExternalForceKernel::copyParametersToContext(ContextImpl& context, const CustomExternalForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumParticles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumParticles()/numContexts;
    if (numParticles != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of particles has changed");
    if (numParticles == 0)
        return;

    // Record the per-particle parameters.

    vector<vector<cl_float> > paramVector(numParticles);
    vector<double> parameters;
    for (int i = 0; i < numParticles; i++) {
        int particle;
        force.getParticleParameters(startIndex+i, particle, parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

4631
/**
 * Describes the particle groups of a CustomHbondForce.
 *
 * Group indices are laid out as: donors first, then acceptors, then exclusions.
 */
class OpenCLCalcCustomHbondForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(int requiredBuffers, const CustomHbondForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    /** Always returns true. */
    bool areParticlesIdentical(int particle1, int particle2) {
        return true;
    }
    /** One group per donor, per acceptor and per exclusion. */
    int getNumParticleGroups() {
        return force.getNumDonors()+force.getNumAcceptors()+force.getNumExclusions();
    }
    /**
     * Fill `particles` with the particles of a group.  A donor or acceptor group holds that
     * donor's or acceptor's particles; an exclusion group holds the particles of its donor
     * followed by those of its acceptor.
     */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int p1, p2, p3;
        vector<double> parameters;
        if (index < force.getNumDonors()) {
            force.getDonorParameters(index, p1, p2, p3, parameters);
            particles.clear();
            appendParticles(particles, p1, p2, p3);
            return;
        }
        index -= force.getNumDonors();
        if (index < force.getNumAcceptors()) {
            force.getAcceptorParameters(index, p1, p2, p3, parameters);
            particles.clear();
            appendParticles(particles, p1, p2, p3);
            return;
        }
        index -= force.getNumAcceptors();
        int donor, acceptor;
        force.getExclusionParticles(index, donor, acceptor);
        particles.clear();
        force.getDonorParameters(donor, p1, p2, p3, parameters);
        appendParticles(particles, p1, p2, p3);
        force.getAcceptorParameters(acceptor, p1, p2, p3, parameters);
        appendParticles(particles, p1, p2, p3);
    }
    /**
     * Two donor groups, or two acceptor groups, are identical when their parameter vectors
     * compare equal.  Two exclusion groups are always identical.  Groups of different kinds
     * are never identical.
     */
    bool areGroupsIdentical(int group1, int group2) {
        int p1, p2, p3;
        vector<double> params1, params2;
        if (group1 < force.getNumDonors() && group2 < force.getNumDonors()) {
            force.getDonorParameters(group1, p1, p2, p3, params1);
            force.getDonorParameters(group2, p1, p2, p3, params2);
            return (params1 == params2);
        }
        if (group1 < force.getNumDonors() || group2 < force.getNumDonors())
            return false;
        group1 -= force.getNumDonors();
        group2 -= force.getNumDonors();
        if (group1 < force.getNumAcceptors() && group2 < force.getNumAcceptors()) {
            force.getAcceptorParameters(group1, p1, p2, p3, params1);
            force.getAcceptorParameters(group2, p1, p2, p3, params2);
            return (params1 == params2);
        }
        if (group1 < force.getNumAcceptors() || group2 < force.getNumAcceptors())
            return false;
        return true;
    }
private:
    /** Append p1, and then p2 and p3 when they are greater than -1. */
    static void appendParticles(vector<int>& particles, int p1, int p2, int p3) {
        particles.push_back(p1);
        if (p2 > -1)
            particles.push_back(p2);
        if (p3 > -1)
            particles.push_back(p3);
    }
    const CustomHbondForce& force;
};

// Releases the donor and acceptor parameter sets owned by this kernel, if they were created.
OpenCLCalcCustomHbondForceKernel::~OpenCLCalcCustomHbondForceKernel() {
    // Deleting a null pointer has no effect, so no explicit checks are needed.
    delete donorParams;
    delete acceptorParams;
}

4714
4715
4716
4717
4718
4719
4720
4721
4722
4723
4724
4725
// Append the same text to both the donor stream and the acceptor stream (donor first).
static void addDonorAndAcceptorCode(stringstream& computeDonor, stringstream& computeAcceptor, const string& value) {
    stringstream* const targets[] = {&computeDonor, &computeAcceptor};
    for (stringstream* target : targets)
        *target << value;
}

// Emit a "<force>.xyz += <value>;" statement for one atom.  Atoms 0-2 are written to the
// acceptor stream as f1-f3; atoms 3-5 are written to the donor stream as f1-f3.
// The atom index is not range checked.
static void applyDonorAndAcceptorForces(stringstream& applyToDonor, stringstream& applyToAcceptor, int atom, const string& value) {
    static const char* const forceNames[] = {"f1", "f2", "f3"};
    const bool goesToAcceptor = (atom < 3);
    stringstream& target = (goesToAcceptor ? applyToAcceptor : applyToDonor);
    const int slot = (goesToAcceptor ? atom : atom-3);
    target << forceNames[slot] << ".xyz += " << value << ";\n";
}
4726

4727
void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const CustomHbondForce& force) {
4728
4729
    // Record the lists of donors and acceptors, and the parameters for each one.

4730
4731
4732
4733
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumDonors()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumDonors()/numContexts;
    numDonors = endIndex-startIndex;
4734
    numAcceptors = force.getNumAcceptors();
4735
4736
    if (numDonors == 0 || numAcceptors == 0)
        return;
4737
    int numParticles = system.getNumParticles();
peastman's avatar
peastman committed
4738
4739
    donors.initialize<mm_int4>(cl, numDonors, "customHbondDonors");
    acceptors.initialize<mm_int4>(cl, numAcceptors, "customHbondAcceptors");
4740
4741
4742
    donorParams = new OpenCLParameterSet(cl, force.getNumPerDonorParameters(), numDonors, "customHbondDonorParameters");
    acceptorParams = new OpenCLParameterSet(cl, force.getNumPerAcceptorParameters(), numAcceptors, "customHbondAcceptorParameters");
    if (force.getNumGlobalParameters() > 0)
peastman's avatar
peastman committed
4743
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customHbondGlobals", CL_MEM_READ_ONLY);
4744
4745
4746
4747
    vector<vector<cl_float> > donorParamVector(numDonors);
    vector<mm_int4> donorVector(numDonors);
    for (int i = 0; i < numDonors; i++) {
        vector<double> parameters;
4748
        force.getDonorParameters(startIndex+i, donorVector[i].x, donorVector[i].y, donorVector[i].z, parameters);
4749
4750
4751
4752
        donorParamVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            donorParamVector[i][j] = (cl_float) parameters[j];
    }
peastman's avatar
peastman committed
4753
    donors.upload(donorVector);
4754
4755
4756
4757
4758
4759
4760
4761
4762
4763
    donorParams->setParameterValues(donorParamVector);
    vector<vector<cl_float> > acceptorParamVector(numAcceptors);
    vector<mm_int4> acceptorVector(numAcceptors);
    for (int i = 0; i < numAcceptors; i++) {
        vector<double> parameters;
        force.getAcceptorParameters(i, acceptorVector[i].x, acceptorVector[i].y, acceptorVector[i].z, parameters);
        acceptorParamVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            acceptorParamVector[i][j] = (cl_float) parameters[j];
    }
peastman's avatar
peastman committed
4764
    acceptors.upload(acceptorVector);
4765
4766
    acceptorParams->setParameterValues(acceptorParamVector);

4767
    // Select an output buffer index for each donor and acceptor.
4768

peastman's avatar
peastman committed
4769
4770
    donorBufferIndices.initialize<mm_int4>(cl, numDonors, "customHbondDonorBuffers");
    acceptorBufferIndices.initialize<mm_int4>(cl, numAcceptors, "customHbondAcceptorBuffers");
4771
4772
    vector<mm_int4> donorBufferVector(numDonors);
    vector<mm_int4> acceptorBufferVector(numAcceptors);
4773
    vector<int> donorBufferCounter(numParticles, 0);
4774
    for (int i = 0; i < numDonors; i++)
4775
4776
4777
        donorBufferVector[i] = mm_int4(donorVector[i].x > -1 ? donorBufferCounter[donorVector[i].x]++ : 0,
                                       donorVector[i].y > -1 ? donorBufferCounter[donorVector[i].y]++ : 0,
                                       donorVector[i].z > -1 ? donorBufferCounter[donorVector[i].z]++ : 0, 0);
4778
    vector<int> acceptorBufferCounter(numParticles, 0);
4779
    for (int i = 0; i < numAcceptors; i++)
4780
4781
4782
        acceptorBufferVector[i] = mm_int4(acceptorVector[i].x > -1 ? acceptorBufferCounter[acceptorVector[i].x]++ : 0,
                                       acceptorVector[i].y > -1 ? acceptorBufferCounter[acceptorVector[i].y]++ : 0,
                                       acceptorVector[i].z > -1 ? acceptorBufferCounter[acceptorVector[i].z]++ : 0, 0);
peastman's avatar
peastman committed
4783
4784
    donorBufferIndices.upload(donorBufferVector);
    acceptorBufferIndices.upload(acceptorBufferVector);
4785
    int maxBuffers = 1;
peastman's avatar
peastman committed
4786
4787
4788
4789
    for (int i : donorBufferCounter)
        maxBuffers = max(maxBuffers, i);
    for (int i : acceptorBufferCounter)
        maxBuffers = max(maxBuffers, i);
4790
4791
    info = new ForceInfo(maxBuffers, force);
    cl.addForce(info);
4792
4793
4794

    // Record exclusions.

4795
4796
    vector<mm_int4> donorExclusionVector(numDonors, mm_int4(-1, -1, -1, -1));
    vector<mm_int4> acceptorExclusionVector(numAcceptors, mm_int4(-1, -1, -1, -1));
4797
4798
4799
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int donor, acceptor;
        force.getExclusionParticles(i, donor, acceptor);
4800
4801
4802
        if (donor < startIndex || donor >= endIndex)
            continue;
        donor -= startIndex;
4803
4804
4805
4806
4807
4808
4809
4810
4811
4812
4813
4814
4815
4816
4817
4818
4819
4820
4821
4822
        if (donorExclusionVector[donor].x == -1)
            donorExclusionVector[donor].x = acceptor;
        else if (donorExclusionVector[donor].y == -1)
            donorExclusionVector[donor].y = acceptor;
        else if (donorExclusionVector[donor].z == -1)
            donorExclusionVector[donor].z = acceptor;
        else if (donorExclusionVector[donor].w == -1)
            donorExclusionVector[donor].w = acceptor;
        else
            throw OpenMMException("CustomHbondForce: OpenCLPlatform does not support more than four exclusions per donor");
        if (acceptorExclusionVector[acceptor].x == -1)
            acceptorExclusionVector[acceptor].x = donor;
        else if (acceptorExclusionVector[acceptor].y == -1)
            acceptorExclusionVector[acceptor].y = donor;
        else if (acceptorExclusionVector[acceptor].z == -1)
            acceptorExclusionVector[acceptor].z = donor;
        else if (acceptorExclusionVector[acceptor].w == -1)
            acceptorExclusionVector[acceptor].w = donor;
        else
            throw OpenMMException("CustomHbondForce: OpenCLPlatform does not support more than four exclusions per acceptor");
4823
    }
peastman's avatar
peastman committed
4824
4825
4826
4827
    donorExclusions.initialize<mm_int4>(cl, numDonors, "customHbondDonorExclusions");
    acceptorExclusions.initialize<mm_int4>(cl, numAcceptors, "customHbondAcceptorExclusions");
    donorExclusions.upload(donorExclusionVector);
    acceptorExclusions.upload(acceptorExclusionVector);
4828
4829
4830
4831
4832

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
4833
    vector<const TabulatedFunction*> functionList;
4834
    stringstream tableArgs;
peastman's avatar
peastman committed
4835
    tabulatedFunctions.resize(force.getNumFunctions());
4836
    for (int i = 0; i < force.getNumFunctions(); i++) {
4837
4838
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
4839
        string arrayName = "table"+cl.intToString(i);
4840
        functionDefinitions.push_back(make_pair(name, arrayName));
4841
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
4842
        int width;
4843
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
4844
4845
        tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
        tabulatedFunctions[i].upload(f);
peastman's avatar
peastman committed
4846
4847
4848
4849
        tableArgs << ", __global const float";
        if (width > 1)
            tableArgs << width;
        tableArgs << "* restrict " << arrayName;
4850
4851
    }

4852
    // Record information about parameters.
4853
4854
4855
4856
4857
4858
4859

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
peastman's avatar
peastman committed
4860
4861
    if (globals.isInitialized())
        globals.upload(globalParamValues);
4862
4863
4864
4865
4866
4867
4868
4869
4870
4871
4872
    map<string, string> variables;
    for (int i = 0; i < force.getNumPerDonorParameters(); i++) {
        const string& name = force.getPerDonorParameterName(i);
        variables[name] = "donorParams"+donorParams->getParameterSuffix(i);
    }
    for (int i = 0; i < force.getNumPerAcceptorParameters(); i++) {
        const string& name = force.getPerAcceptorParameterName(i);
        variables[name] = "acceptorParams"+acceptorParams->getParameterSuffix(i);
    }
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        const string& name = force.getGlobalParameterName(i);
4873
        variables[name] = "globals["+cl.intToString(i)+"]";
4874
    }
4875
4876
4877
4878
4879
4880
4881
4882
4883
4884
4885
4886
4887
4888
4889

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    computedDeltas.insert("D1A1");
    string atomNames[] = {"A1", "A2", "A3", "D1", "D2", "D3"};
    string atomNamesLower[] = {"a1", "a2", "a3", "d1", "d2", "d3"};
    stringstream computeDonor, computeAcceptor, extraArgs;
    int index = 0;
peastman's avatar
peastman committed
4890
4891
    for (auto& distance : distances) {
        const vector<int>& atoms = distance.second;
4892
4893
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        if (computedDeltas.count(deltaName) == 0) {
Peter Eastman's avatar
Peter Eastman committed
4894
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
4895
4896
            computedDeltas.insert(deltaName);
        }
4897
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real r_"+deltaName+" = SQRT(delta"+deltaName+".w);\n");
peastman's avatar
peastman committed
4898
4899
4900
        variables[distance.first] = "r_"+deltaName;
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(distance.first).optimize();
        index++;
4901
4902
    }
    index = 0;
peastman's avatar
peastman committed
4903
4904
    for (auto& angle : angles) {
        const vector<int>& atoms = angle.second;
4905
4906
4907
4908
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
        if (computedDeltas.count(deltaName1) == 0) {
Peter Eastman's avatar
Peter Eastman committed
4909
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[0]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
4910
4911
4912
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
Peter Eastman's avatar
Peter Eastman committed
4913
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[2]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
4914
4915
            computedDeltas.insert(deltaName2);
        }
4916
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real "+angleName+" = computeAngle(delta"+deltaName1+", delta"+deltaName2+");\n");
peastman's avatar
peastman committed
4917
4918
4919
        variables[angle.first] = angleName;
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(angle.first).optimize();
        index++;
4920
4921
    }
    index = 0;
peastman's avatar
peastman committed
4922
4923
    for (auto& dihedral : dihedrals) {
        const vector<int>& atoms = dihedral.second;
4924
4925
4926
4927
4928
4929
4930
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
        if (computedDeltas.count(deltaName1) == 0) {
Peter Eastman's avatar
Peter Eastman committed
4931
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
4932
4933
4934
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
Peter Eastman's avatar
Peter Eastman committed
4935
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[1]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
4936
4937
4938
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
Peter Eastman's avatar
Peter Eastman committed
4939
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName3+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[3]]+", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n");
4940
4941
            computedDeltas.insert(deltaName3);
        }
4942
4943
4944
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 "+crossName1+" = computeCross(delta"+deltaName1+", delta"+deltaName2+");\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 "+crossName2+" = computeCross(delta"+deltaName2+", delta"+deltaName3+");\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real "+dihedralName+" = computeAngle("+crossName1+", "+crossName2+");\n");
4945
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, dihedralName+" *= (delta"+deltaName1+".x*"+crossName2+".x + delta"+deltaName1+".y*"+crossName2+".y + delta"+deltaName1+".z*"+crossName2+".z < 0 ? -1 : 1);\n");
peastman's avatar
peastman committed
4946
4947
4948
        variables[dihedral.first] = dihedralName;
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(dihedral.first).optimize();
        index++;
4949
4950
4951
4952
    }

    // Next it needs to load parameters from global memory.

4953
    if (force.getNumGlobalParameters() > 0)
4954
        extraArgs << ", __global const float* restrict globals";
4955
4956
    for (int i = 0; i < (int) donorParams->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = donorParams->getBuffers()[i];
4957
        extraArgs << ", __global const "+buffer.getType()+"* restrict donor"+buffer.getName();
4958
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, buffer.getType()+" donorParams"+cl.intToString(i+1)+" = donor"+buffer.getName()+"[donorIndex];\n");
4959
4960
4961
    }
    for (int i = 0; i < (int) acceptorParams->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
4962
        extraArgs << ", __global const "+buffer.getType()+"* restrict acceptor"+buffer.getName();
4963
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, buffer.getType()+" acceptorParams"+cl.intToString(i+1)+" = acceptor"+buffer.getName()+"[acceptorIndex];\n");
4964
    }
4965
4966
4967

    // Now evaluate the expressions.

4968
    computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
4969
    forceExpressions["energy += "] = energyExpression;
4970
    computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
4971
4972
4973
4974

    // Finally, apply forces to atoms.

    index = 0;
peastman's avatar
peastman committed
4975
4976
    for (auto& distance : distances) {
        const vector<int>& atoms = distance.second;
4977
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
4978
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
4979
4980
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "-"+value);
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], value);
peastman's avatar
peastman committed
4981
        index++;
4982
4983
    }
    index = 0;
peastman's avatar
peastman committed
4984
4985
    for (auto& angle : angles) {
        const vector<int>& atoms = angle.second;
4986
4987
4988
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "{\n");
4989
4990
4991
4992
4993
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 crossProd = cross(delta"+deltaName2+", delta"+deltaName1+");\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real lengthCross = max(length(crossProd), (real) 1e-6f);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 deltaCross0 = -cross(delta"+deltaName1+", crossProd)*dEdAngle"+cl.intToString(index)+"/(delta"+deltaName1+".w*lengthCross);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 deltaCross2 = cross(delta"+deltaName2+", crossProd)*dEdAngle"+cl.intToString(index)+"/(delta"+deltaName2+".w*lengthCross);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 deltaCross1 = -(deltaCross0+deltaCross2);\n");
4994
4995
4996
4997
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "deltaCross0.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], "deltaCross1.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[2], "deltaCross2.xyz");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "}\n");
peastman's avatar
peastman committed
4998
        index++;
4999
5000
    }
    index = 0;
peastman's avatar
peastman committed
5001
5002
    for (auto& dihedral : dihedrals) {
        const vector<int>& atoms = dihedral.second;
5003
5004
5005
5006
5007
5008
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "{\n");
5009
5010
5011
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real r = SQRT(delta"+deltaName2+".w);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 ff;\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.x = (-dEdDihedral"+cl.intToString(index)+"*r)/"+crossName1+".w;\n");
5012
5013
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.y = (delta"+deltaName1+".x*delta"+deltaName2+".x + delta"+deltaName1+".y*delta"+deltaName2+".y + delta"+deltaName1+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.z = (delta"+deltaName3+".x*delta"+deltaName2+".x + delta"+deltaName3+".y*delta"+deltaName2+".y + delta"+deltaName3+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n");
5014
5015
5016
5017
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.w = (dEdDihedral"+cl.intToString(index)+"*r)/"+crossName2+".w;\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 internalF0 = ff.x*"+crossName1+";\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 internalF3 = ff.w*"+crossName2+";\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 s = ff.y*internalF0 - ff.z*internalF3;\n");
5018
5019
5020
5021
5022
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "internalF0.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], "s.xyz-internalF0.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[2], "-s.xyz-internalF3.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[3], "internalF3.xyz");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "}\n");
peastman's avatar
peastman committed
5023
        index++;
5024
5025
5026
5027
    }

    // Generate the kernels.

5028
    map<string, string> replacements;
5029
5030
    replacements["COMPUTE_DONOR_FORCE"] = computeDonor.str();
    replacements["COMPUTE_ACCEPTOR_FORCE"] = computeAcceptor.str();
5031
5032
    replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
    map<string, string> defines;
5033
5034
5035
5036
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    defines["NUM_DONORS"] = cl.intToString(numDonors);
    defines["NUM_ACCEPTORS"] = cl.intToString(numAcceptors);
    defines["PI"] = cl.doubleToString(M_PI);
5037
5038
    if (force.getNonbondedMethod() != CustomHbondForce::NoCutoff) {
        defines["USE_CUTOFF"] = "1";
5039
        defines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
5040
5041
5042
    }
    if (force.getNonbondedMethod() != CustomHbondForce::NoCutoff && force.getNonbondedMethod() != CustomHbondForce::CutoffNonPeriodic)
        defines["USE_PERIODIC"] = "1";
5043
5044
    if (force.getNumExclusions() > 0)
        defines["USE_EXCLUSIONS"] = "1";
5045
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customHbondForce, replacements), defines);
5046
5047
    donorKernel = cl::Kernel(program, "computeDonorForces");
    acceptorKernel = cl::Kernel(program, "computeAcceptorForces");
5048
5049
}

5050
/**
 * Launch the donor and acceptor kernels for the current state.
 *
 * Global parameter values are re-read from the context on every call and uploaded
 * only if at least one has changed.  Kernel arguments that never change are bound
 * on the first call; the periodic box arguments are refreshed on every call.
 *
 * @param context        the context supplying the current global parameter values
 * @param includeForces  not consulted by this implementation
 * @param includeEnergy  not consulted by this implementation
 * @return always 0.0 (the kernels are handed the context's energy buffer, so the
 *         energy is presumably accumulated there -- confirm against the kernel source)
 */
double OpenCLCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (numDonors == 0 || numAcceptors == 0)
        return 0.0;
    if (globals.isInitialized()) {
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals.upload(globalParamValues);
    }

    // Eight arguments (three shared buffers, four per-kernel buffers and one local
    // scratch block) precede the five periodic box arguments in both kernels.
    const int periodicBoxArgIndex = 8;
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;
        int index = 0;
        donorKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, donorExclusions.getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, donors.getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, acceptors.getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, donorBufferIndices.getDeviceBuffer());
        donorKernel.setArg(index++, 3*OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
        index += 5; // Periodic box size arguments are set when the kernel is executed.
        if (globals.isInitialized())
            donorKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (auto& buffer : donorParams->getBuffers())
            donorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& buffer : acceptorParams->getBuffers())
            donorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& function : tabulatedFunctions)
            donorKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());

        // The acceptor kernel takes the same argument list, except that it uses the
        // acceptor exclusions and the acceptor buffer indices.
        index = 0;
        acceptorKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, acceptorExclusions.getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, donors.getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, acceptors.getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, acceptorBufferIndices.getDeviceBuffer());
        acceptorKernel.setArg(index++, 3*OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
        index += 5; // Periodic box size arguments are set when the kernel is executed.
        if (globals.isInitialized())
            acceptorKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (auto& buffer : donorParams->getBuffers())
            acceptorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& buffer : acceptorParams->getBuffers())
            acceptorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& function : tabulatedFunctions)
            acceptorKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());
    }
    setPeriodicBoxArgs(cl, donorKernel, periodicBoxArgIndex);
    cl.executeKernel(donorKernel, max(numDonors, numAcceptors));
    setPeriodicBoxArgs(cl, acceptorKernel, periodicBoxArgIndex);
    cl.executeKernel(acceptorKernel, max(numDonors, numAcceptors));
    return 0.0;
}

5110
5111
5112
5113
5114
5115
5116
5117
5118
5119
5120
/**
 * Re-upload the per-donor and per-acceptor parameter values from a force.
 *
 * Only parameter values are refreshed; the donor and acceptor counts handled by this
 * context must match those recorded when the kernel was initialized.
 *
 * @param context  the context being updated (not consulted by this implementation)
 * @param force    the CustomHbondForce to read the parameters from
 * @throws OpenMMException if the number of donors or acceptors has changed
 */
void OpenCLCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& context, const CustomHbondForce& force) {
    // Recompute this context's share of the donors, using the same split as initialize().
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumDonors()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumDonors()/numContexts;
    if (numDonors != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of donors has changed");
    if (numAcceptors != force.getNumAcceptors())
        throw OpenMMException("updateParametersInContext: The number of acceptors has changed");

    // Record the per-donor parameters.

    if (numDonors > 0) {
        vector<vector<cl_float> > donorParamVector(numDonors);
        vector<double> parameters;
        for (int i = 0; i < numDonors; i++) {
            int d1, d2, d3;  // Particle indices are fetched but not used here.
            force.getDonorParameters(startIndex+i, d1, d2, d3, parameters);
            donorParamVector[i].resize(parameters.size());
            for (int j = 0; j < (int) parameters.size(); j++)
                donorParamVector[i][j] = (cl_float) parameters[j];
        }
        donorParams->setParameterValues(donorParamVector);
    }

    // Record the per-acceptor parameters.

    if (numAcceptors > 0) {
        vector<vector<cl_float> > acceptorParamVector(numAcceptors);
        vector<double> parameters;
        for (int i = 0; i < numAcceptors; i++) {
            int a1, a2, a3;  // Particle indices are fetched but not used here.
            force.getAcceptorParameters(i, a1, a2, a3, parameters);
            acceptorParamVector[i].resize(parameters.size());
            for (int j = 0; j < (int) parameters.size(); j++)
                acceptorParamVector[i][j] = (cl_float) parameters[j];
        }
        acceptorParams->setParameterValues(acceptorParamVector);
    }

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

5154
/**
 * Describes the particle groups of a CustomCentroidBondForce: one group per bond,
 * containing the particles of every centroid group that the bond references.
 */
class OpenCLCalcCustomCentroidBondForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CustomCentroidBondForce& force) : OpenCLForceInfo(0), force(force) {
    }
    /**
     * @return the number of particle groups, which equals the number of bonds
     */
    int getNumParticleGroups() {
        return force.getNumBonds();
    }
    /**
     * Append to `particles` the particles of every centroid group used by a bond.
     *
     * @param index      index of the bond
     * @param particles  list that the particle indices are appended to
     */
    void getParticlesInGroup(int index, vector<int>& particles) {
        vector<double> parameters;
        vector<int> groups;
        force.getBondParameters(index, groups, parameters);
        for (int group : groups) {
            vector<int> groupParticles;
            vector<double> weights;
            force.getGroupParameters(group, groupParticles, weights);
            particles.insert(particles.end(), groupParticles.begin(), groupParticles.end());
        }
    }
    /**
     * Decide whether two bonds are interchangeable: they must have equal per-bond
     * parameters and, position by position, centroid groups with equal weights.
     * The particle indices inside the groups are not compared.
     *
     * @param group1  index of the first bond
     * @param group2  index of the second bond
     * @return true if the two bonds are identical in the sense described above
     */
    bool areGroupsIdentical(int group1, int group2) {
        vector<int> groups1, groups2;
        vector<double> parameters1, parameters2;
        force.getBondParameters(group1, groups1, parameters1);
        force.getBondParameters(group2, groups2, parameters2);
        if (parameters1 != parameters2)
            return false;
        // Guard the positional comparison below against lists of different length.
        if (groups1.size() != groups2.size())
            return false;
        for (size_t i = 0; i < groups1.size(); i++) {
            vector<int> groupParticles;  // Fetched but not compared.
            vector<double> weights1, weights2;
            force.getGroupParameters(groups1[i], groupParticles, weights1);
            force.getGroupParameters(groups2[i], groupParticles, weights2);
            if (weights1 != weights2)
                return false;
        }
        return true;
    }
private:
    const CustomCentroidBondForce& force;
};

/**
 * Release the parameter set owned by this kernel.
 */
OpenCLCalcCustomCentroidBondForceKernel::~OpenCLCalcCustomCentroidBondForceKernel() {
    // Applying delete to a null pointer is a no-op, so no explicit guard is required.
    delete params;
}

void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, const CustomCentroidBondForce& force) {
    numBonds = force.getNumBonds();
    if (numBonds == 0)
        return;
    if (!cl.getSupports64BitGlobalAtomics())
        throw OpenMMException("CustomCentroidBondForce requires a device that supports 64 bit atomic operations");
5208
5209
    info = new ForceInfo(force);
    cl.addForce(info);
5210
5211
5212
5213
5214
    
    // Record the groups.
    
    numGroups = force.getNumGroups();
    vector<cl_int> groupParticleVec;
peastman's avatar
peastman committed
5215
    vector<cl_double> groupWeightVec;
5216
5217
5218
5219
5220
5221
5222
5223
5224
5225
5226
    vector<cl_int> groupOffsetVec;
    groupOffsetVec.push_back(0);
    for (int i = 0; i < numGroups; i++) {
        vector<int> particles;
        vector<double> weights;
        force.getGroupParameters(i, particles, weights);
        groupParticleVec.insert(groupParticleVec.end(), particles.begin(), particles.end());
        groupOffsetVec.push_back(groupParticleVec.size());
    }
    vector<vector<double> > normalizedWeights;
    CustomCentroidBondForceImpl::computeNormalizedWeights(force, system, normalizedWeights);
peastman's avatar
peastman committed
5227
5228
    for (int i = 0; i < numGroups; i++)
        groupWeightVec.insert(groupWeightVec.end(), normalizedWeights[i].begin(), normalizedWeights[i].end());
peastman's avatar
peastman committed
5229
5230
    groupParticles.initialize<int>(cl, groupParticleVec.size(), "groupParticles");
    groupParticles.upload(groupParticleVec);
5231
    if (cl.getUseDoublePrecision()) {
peastman's avatar
peastman committed
5232
5233
        groupWeights.initialize<double>(cl, groupParticleVec.size(), "groupWeights");
        centerPositions.initialize<mm_double4>(cl, numGroups, "centerPositions");
5234
5235
    }
    else {
peastman's avatar
peastman committed
5236
5237
5238
        groupWeights.initialize<float>(cl, groupParticleVec.size(), "groupWeights");
        centerPositions.initialize<mm_float4>(cl, numGroups, "centerPositions");
    }
peastman's avatar
peastman committed
5239
    groupWeights.upload(groupWeightVec, true, true);
peastman's avatar
peastman committed
5240
5241
5242
5243
    groupOffsets.initialize<int>(cl, groupOffsetVec.size(), "groupOffsets");
    groupOffsets.upload(groupOffsetVec);
    groupForces.initialize<long long>(cl, numGroups*3, "groupForces");
    cl.addAutoclearBuffer(groupForces);
5244
5245
5246
5247
5248
5249
5250
5251
5252
5253
5254
5255
5256
5257
5258
5259
5260
5261
    
    // Record the bonds.
    
    int groupsPerBond = force.getNumGroupsPerBond();
    vector<cl_int> bondGroupVec(numBonds*groupsPerBond);
    params = new OpenCLParameterSet(cl, force.getNumPerBondParameters(), numBonds, "customCentroidBondParams");
    vector<vector<float> > paramVector(numBonds);
    for (int i = 0; i < numBonds; i++) {
        vector<int> groups;
        vector<double> parameters;
        force.getBondParameters(i, groups, parameters);
        for (int j = 0; j < groups.size(); j++)
            bondGroupVec[i+j*numBonds] = groups[j];
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (float) parameters[j];
    }
    params->setParameterValues(paramVector);
peastman's avatar
peastman committed
5262
5263
    bondGroups.initialize<int>(cl, bondGroupVec.size(), "bondGroups");
    bondGroups.upload(bondGroupVec);
5264
5265
5266
5267
5268
5269
5270

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
    vector<const TabulatedFunction*> functionList;
    stringstream extraArgs;
peastman's avatar
peastman committed
5271
    tabulatedFunctions.resize(force.getNumTabulatedFunctions());
5272
5273
5274
5275
5276
5277
5278
5279
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
        string arrayName = "table"+cl.intToString(i);
        functionDefinitions.push_back(make_pair(name, arrayName));
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
        int width;
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
5280
5281
        tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
        tabulatedFunctions[i].upload(f);
5282
5283
5284
5285
5286
5287
5288
5289
5290
5291
5292
5293
5294
5295
5296
5297
5298
5299
5300
5301
5302
5303
5304
5305
5306
        extraArgs << ", __global const float";
        if (width > 1)
            extraArgs << width;
        extraArgs << "* restrict " << arrayName;
    }
    
    // Record information about parameters.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
    }
    map<string, string> variables;
    for (int i = 0; i < groupsPerBond; i++) {
        string index = cl.intToString(i+1);
        variables["x"+index] = "pos"+index+".x";
        variables["y"+index] = "pos"+index+".y";
        variables["z"+index] = "pos"+index+".z";
    }
    for (int i = 0; i < force.getNumPerBondParameters(); i++) {
        const string& name = force.getPerBondParameterName(i);
        variables[name] = "bondParams"+params->getParameterSuffix(i);
    }
5307
5308
5309
    needEnergyParamDerivs = (force.getNumEnergyParameterDerivatives() > 0);
    if (needEnergyParamDerivs)
        extraArgs << ", __global mixed* restrict energyParamDerivs";
5310
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
5311
5312
        globals.initialize<float>(cl, force.getNumGlobalParameters(), "customCentroidBondGlobals");
        globals.upload(globalParamValues);
5313
5314
5315
5316
5317
5318
5319
5320
5321
5322
5323
5324
5325
5326
5327
5328
5329
5330
5331
5332
5333
5334
5335
        extraArgs << ", __global const float* restrict globals";
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
            string value = "globals["+cl.intToString(i)+"]";
            variables[name] = value;
        }
    }

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomCentroidBondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    vector<string> atomNames, posNames;
    for (int i = 0; i < groupsPerBond; i++) {
        string index = cl.intToString(i+1);
        atomNames.push_back("P"+index);
        posNames.push_back("pos"+index);
    }
5336
    stringstream compute, initParamDerivs, saveParamDerivs;
5337
5338
5339
5340
5341
    for (int i = 0; i < groupsPerBond; i++) {
        compute<<"int group"<<(i+1)<<" = bondGroups[index+"<<(i*numBonds)<<"];\n";
        compute<<"real4 pos"<<(i+1)<<" = centerPositions[group"<<(i+1)<<"];\n";
    }
    int index = 0;
peastman's avatar
peastman committed
5342
5343
    for (auto& distance : distances) {
        const vector<int>& groups = distance.second;
5344
5345
        string deltaName = atomNames[groups[0]]+atomNames[groups[1]];
        if (computedDeltas.count(deltaName) == 0) {
5346
            compute<<"real4 delta"<<deltaName<<" = delta("<<posNames[groups[0]]<<", "<<posNames[groups[1]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5347
5348
5349
            computedDeltas.insert(deltaName);
        }
        compute<<"real r_"<<deltaName<<" = sqrt(delta"<<deltaName<<".w);\n";
peastman's avatar
peastman committed
5350
5351
5352
        variables[distance.first] = "r_"+deltaName;
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(distance.first).optimize();
        index++;
5353
5354
    }
    index = 0;
peastman's avatar
peastman committed
5355
5356
    for (auto& angle : angles) {
        const vector<int>& groups = angle.second;
5357
5358
5359
5360
        string deltaName1 = atomNames[groups[1]]+atomNames[groups[0]];
        string deltaName2 = atomNames[groups[1]]+atomNames[groups[2]];
        string angleName = "angle_"+atomNames[groups[0]]+atomNames[groups[1]]+atomNames[groups[2]];
        if (computedDeltas.count(deltaName1) == 0) {
5361
            compute<<"real4 delta"<<deltaName1<<" = delta("<<posNames[groups[1]]<<", "<<posNames[groups[0]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5362
5363
5364
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
5365
            compute<<"real4 delta"<<deltaName2<<" = delta("<<posNames[groups[1]]<<", "<<posNames[groups[2]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5366
5367
5368
            computedDeltas.insert(deltaName2);
        }
        compute<<"real "<<angleName<<" = computeAngle(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
peastman's avatar
peastman committed
5369
5370
5371
        variables[angle.first] = angleName;
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(angle.first).optimize();
        index++;
5372
5373
    }
    index = 0;
peastman's avatar
peastman committed
5374
5375
    for (auto& dihedral : dihedrals) {
        const vector<int>& groups = dihedral.second;
5376
5377
5378
5379
5380
5381
5382
        string deltaName1 = atomNames[groups[0]]+atomNames[groups[1]];
        string deltaName2 = atomNames[groups[2]]+atomNames[groups[1]];
        string deltaName3 = atomNames[groups[2]]+atomNames[groups[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[groups[0]]+atomNames[groups[1]]+atomNames[groups[2]]+atomNames[groups[3]];
        if (computedDeltas.count(deltaName1) == 0) {
5383
            compute<<"real4 delta"<<deltaName1<<" = delta("<<posNames[groups[0]]<<", "<<posNames[groups[1]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5384
5385
5386
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
5387
            compute<<"real4 delta"<<deltaName2<<" = delta("<<posNames[groups[2]]<<", "<<posNames[groups[1]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5388
5389
5390
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
5391
            compute<<"real4 delta"<<deltaName3<<" = delta("<<posNames[groups[2]]<<", "<<posNames[groups[3]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5392
5393
5394
5395
5396
5397
            computedDeltas.insert(deltaName3);
        }
        compute<<"real4 "<<crossName1<<" = computeCross(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
        compute<<"real4 "<<crossName2<<" = computeCross(delta"<<deltaName2<<", delta"<<deltaName3<<");\n";
        compute<<"real "<<dihedralName<<" = computeAngle("<<crossName1<<", "<<crossName2<<");\n";
        compute<<dihedralName<<" *= (delta"<<deltaName1<<".x*"<<crossName2<<".x + delta"<<deltaName1<<".y*"<<crossName2<<".y + delta"<<deltaName1<<".z*"<<crossName2<<".z < 0 ? -1 : 1);\n";
peastman's avatar
peastman committed
5398
5399
5400
        variables[dihedral.first] = dihedralName;
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(dihedral.first).optimize();
        index++;
5401
5402
5403
5404
5405
5406
5407
5408
5409
5410
    }

    // Now evaluate the expressions.

    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        extraArgs<<", __global const "<<buffer.getType()<<"* restrict globalParams"<<i;
        compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = globalParams"<<i<<"[index];\n";
    }
    forceExpressions["energy += "] = energyExpression;
5411
5412
5413
5414
5415
5416
5417
5418
5419
5420
5421
5422
5423
5424
5425
    if (needEnergyParamDerivs) {
        for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
            string paramName = force.getEnergyParameterDerivativeName(i);
            cl.addEnergyParameterDerivative(paramName);
            Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
            forceExpressions[string("energyParamDeriv")+cl.intToString(i)+" += "] = derivExpression;
            initParamDerivs << "mixed energyParamDeriv" << i << " = 0;\n";
        }
        const vector<string>& allParamDerivNames = cl.getEnergyParamDerivNames();
        int numDerivs = allParamDerivNames.size();
        for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
            for (int index = 0; index < numDerivs; index++)
                if (allParamDerivNames[index] == force.getEnergyParameterDerivativeName(i))
                    saveParamDerivs << "energyParamDerivs[get_global_id(0)*" << numDerivs << "+" << index << "] += energyParamDeriv" << i << ";\n";
    }
5426
5427
5428
5429
5430
5431
5432
5433
5434
5435
5436
5437
5438
5439
5440
5441
5442
5443
5444
5445
5446
5447
5448
5449
5450
5451
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");

    // Finally, apply forces to groups.

    vector<string> forceNames;
    for (int i = 0; i < groupsPerBond; i++) {
        string istr = cl.intToString(i+1);
        string forceName = "force"+istr;
        forceNames.push_back(forceName);
        compute<<"real3 "<<forceName<<" = (real3) 0;\n";
        compute<<"{\n";
        Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x"+istr).optimize();
        Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y"+istr).optimize();
        Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z"+istr).optimize();
        map<string, Lepton::ParsedExpression> expressions;
        if (!isZeroExpression(forceExpressionX))
            expressions[forceName+".x -= "] = forceExpressionX;
        if (!isZeroExpression(forceExpressionY))
            expressions[forceName+".y -= "] = forceExpressionY;
        if (!isZeroExpression(forceExpressionZ))
            expressions[forceName+".z -= "] = forceExpressionZ;
        if (expressions.size() > 0)
            compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
        compute<<"}\n";
    }
    index = 0;
peastman's avatar
peastman committed
5452
5453
    for (auto& distance : distances) {
        const vector<int>& groups = distance.second;
5454
5455
5456
5457
        string deltaName = atomNames[groups[0]]+atomNames[groups[1]];
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
        compute<<forceNames[groups[0]]<<" += "<<"-"<<value<<";\n";
        compute<<forceNames[groups[1]]<<" += "<<value<<";\n";
peastman's avatar
peastman committed
5458
        index++;
5459
5460
    }
    index = 0;
peastman's avatar
peastman committed
5461
5462
    for (auto& angle : angles) {
        const vector<int>& groups = angle.second;
5463
5464
5465
5466
5467
5468
5469
5470
5471
5472
5473
5474
        string deltaName1 = atomNames[groups[1]]+atomNames[groups[0]];
        string deltaName2 = atomNames[groups[1]]+atomNames[groups[2]];
        compute<<"{\n";
        compute<<"real4 crossProd = cross(delta"<<deltaName2<<", delta"<<deltaName1<<");\n";
        compute<<"real lengthCross = max(length(crossProd), (real) 1e-6f);\n";
        compute<<"real4 deltaCross0 = -cross(delta"<<deltaName1<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName1<<".w*lengthCross);\n";
        compute<<"real4 deltaCross2 = cross(delta"<<deltaName2<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName2<<".w*lengthCross);\n";
        compute<<"real4 deltaCross1 = -(deltaCross0+deltaCross2);\n";
        compute<<forceNames[groups[0]]<<".xyz += deltaCross0.xyz;\n";
        compute<<forceNames[groups[1]]<<".xyz += deltaCross1.xyz;\n";
        compute<<forceNames[groups[2]]<<".xyz += deltaCross2.xyz;\n";
        compute<<"}\n";
peastman's avatar
peastman committed
5475
        index++;
5476
5477
    }
    index = 0;
peastman's avatar
peastman committed
5478
5479
    for (auto& dihedral : dihedrals) {
        const vector<int>& groups = dihedral.second;
5480
5481
5482
5483
5484
5485
5486
5487
5488
5489
5490
5491
5492
5493
5494
5495
5496
5497
5498
5499
        string deltaName1 = atomNames[groups[0]]+atomNames[groups[1]];
        string deltaName2 = atomNames[groups[2]]+atomNames[groups[1]];
        string deltaName3 = atomNames[groups[2]]+atomNames[groups[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        compute<<"{\n";
        compute<<"real r = sqrt(delta"<<deltaName2<<".w);\n";
        compute<<"real4 ff;\n";
        compute<<"ff.x = (-dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName1<<".w;\n";
        compute<<"ff.y = (delta"<<deltaName1<<".x*delta"<<deltaName2<<".x + delta"<<deltaName1<<".y*delta"<<deltaName2<<".y + delta"<<deltaName1<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.z = (delta"<<deltaName3<<".x*delta"<<deltaName2<<".x + delta"<<deltaName3<<".y*delta"<<deltaName2<<".y + delta"<<deltaName3<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.w = (dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName2<<".w;\n";
        compute<<"real4 internalF0 = ff.x*"<<crossName1<<";\n";
        compute<<"real4 internalF3 = ff.w*"<<crossName2<<";\n";
        compute<<"real4 s = ff.y*internalF0 - ff.z*internalF3;\n";
        compute<<forceNames[groups[0]]<<".xyz += internalF0.xyz;\n";
        compute<<forceNames[groups[1]]<<".xyz += s.xyz-internalF0.xyz;\n";
        compute<<forceNames[groups[2]]<<".xyz += -s.xyz-internalF3.xyz;\n";
        compute<<forceNames[groups[3]]<<".xyz += internalF3.xyz;\n";
        compute<<"}\n";
peastman's avatar
peastman committed
5500
        index++;
5501
5502
5503
5504
5505
5506
5507
5508
5509
5510
5511
5512
5513
5514
5515
5516
    }
    
    // Save the forces to global memory.
    
    for (int i = 0; i < groupsPerBond; i++) {
        compute<<"atom_add(&groupForce[group"<<(i+1)<<"], (long) (force"<<(i+1)<<".x*0x100000000));\n";
        compute<<"atom_add(&groupForce[group"<<(i+1)<<"+NUM_GROUPS], (long) (force"<<(i+1)<<".y*0x100000000));\n";
        compute<<"atom_add(&groupForce[group"<<(i+1)<<"+NUM_GROUPS*2], (long) (force"<<(i+1)<<".z*0x100000000));\n";
    }
    map<string, string> replacements;
    replacements["M_PI"] = cl.doubleToString(M_PI);
    replacements["NUM_GROUPS"] = cl.intToString(numGroups);
    replacements["NUM_BONDS"] = cl.intToString(numBonds);
    replacements["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    replacements["EXTRA_ARGS"] = extraArgs.str();
    replacements["COMPUTE_FORCE"] = compute.str();
5517
5518
    replacements["INIT_PARAM_DERIVS"] = initParamDerivs.str();
    replacements["SAVE_PARAM_DERIVS"] = saveParamDerivs.str();
5519
5520
5521
5522
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customCentroidBond, replacements));
    index = 0;
    computeCentersKernel = cl::Kernel(program, "computeGroupCenters");
    computeCentersKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
peastman's avatar
peastman committed
5523
5524
5525
5526
    computeCentersKernel.setArg<cl::Buffer>(index++, groupParticles.getDeviceBuffer());
    computeCentersKernel.setArg<cl::Buffer>(index++, groupWeights.getDeviceBuffer());
    computeCentersKernel.setArg<cl::Buffer>(index++, groupOffsets.getDeviceBuffer());
    computeCentersKernel.setArg<cl::Buffer>(index++, centerPositions.getDeviceBuffer());
5527
5528
    index = 0;
    groupForcesKernel = cl::Kernel(program, "computeGroupForces");
peastman's avatar
peastman committed
5529
    groupForcesKernel.setArg<cl::Buffer>(index++, groupForces.getDeviceBuffer());
5530
    index++; // Energy buffer hasn't been created yet
peastman's avatar
peastman committed
5531
5532
    groupForcesKernel.setArg<cl::Buffer>(index++, centerPositions.getDeviceBuffer());
    groupForcesKernel.setArg<cl::Buffer>(index++, bondGroups.getDeviceBuffer());
5533
    index += 5; // Periodic box information
5534
5535
    if (needEnergyParamDerivs)
        index++; // Deriv buffer hasn't been created yet.
peastman's avatar
peastman committed
5536
5537
5538
5539
    for (auto& function : tabulatedFunctions)
        groupForcesKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());
    if (globals.isInitialized())
        groupForcesKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
peastman's avatar
peastman committed
5540
5541
    for (auto& buffer : params->getBuffers())
        groupForcesKernel.setArg<cl::Memory>(index++, buffer.getMemory());
5542
5543
    index = 0;
    applyForcesKernel = cl::Kernel(program, "applyForcesToAtoms");
peastman's avatar
peastman committed
5544
5545
5546
5547
    applyForcesKernel.setArg<cl::Buffer>(index++, groupParticles.getDeviceBuffer());
    applyForcesKernel.setArg<cl::Buffer>(index++, groupWeights.getDeviceBuffer());
    applyForcesKernel.setArg<cl::Buffer>(index++, groupOffsets.getDeviceBuffer());
    applyForcesKernel.setArg<cl::Buffer>(index++, groupForces.getDeviceBuffer());
5548
5549
5550
}

/**
 * Run the three kernels that evaluate the force: compute the group centers, compute the
 * per-group forces for every bond, then apply those forces to the atoms.
 *
 * Global parameter values are refreshed from the context first and re-uploaded only when
 * at least one of them differs from the cached copy.
 *
 * @return always 0.0; the energy buffer is passed to the force kernel as argument 1
 */
double OpenCLCalcCustomCentroidBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (numBonds == 0)
        return 0.0;

    // Refresh the cached global parameters and note whether any of them changed.
    if (globals.isInitialized()) {
        bool needsUpload = false;
        const int paramCount = (int) globalParamNames.size();
        for (int p = 0; p < paramCount; p++) {
            const float latest = (float) context.getParameter(globalParamNames[p]);
            needsUpload = needsUpload || (latest != globalParamValues[p]);
            globalParamValues[p] = latest;
        }
        if (needsUpload)
            globals.upload(globalParamValues);
    }

    // Stage 1: group centers.
    cl.executeKernel(computeCentersKernel, OpenCLContext::TileSize*numGroups);

    // Stage 2: forces on groups.  Bind the arguments that were not available at initialization.
    groupForcesKernel.setArg<cl::Buffer>(1, cl.getEnergyBuffer().getDeviceBuffer());
    setPeriodicBoxArgs(cl, groupForcesKernel, 4);
    if (needEnergyParamDerivs)
        groupForcesKernel.setArg<cl::Memory>(9, cl.getEnergyParamDerivBuffer().getDeviceBuffer());
    cl.executeKernel(groupForcesKernel, numBonds);

    // Stage 3: distribute the group forces onto the atoms.
    applyForcesKernel.setArg<cl::Buffer>(4, cl.getLongForceBuffer().getDeviceBuffer());
    cl.executeKernel(applyForcesKernel, OpenCLContext::TileSize*numGroups);
    return 0.0;
}

/**
 * Re-read the per-bond parameters from the force and push them to the device.
 *
 * Only the per-bond parameter values are updated here.
 *
 * @throws OpenMMException if the number of bonds differs from the one seen at initialization
 */
void OpenCLCalcCustomCentroidBondForceKernel::copyParametersToContext(ContextImpl& context, const CustomCentroidBondForce& force) {
    if (numBonds != force.getNumBonds())
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;

    // Gather the per-bond parameters, converted to single precision.

    vector<vector<float> > bondParamTable(numBonds);
    vector<int> bondGroupIds;
    vector<double> bondParamValues;
    for (int bond = 0; bond < numBonds; bond++) {
        force.getBondParameters(bond, bondGroupIds, bondParamValues);
        vector<float>& row = bondParamTable[bond];
        row.reserve(bondParamValues.size());
        for (double value : bondParamValues)
            row.push_back((float) value);
    }
    params->setParameterValues(bondParamTable);

    // The current atom reordering may no longer be valid.

    cl.invalidateMolecules(info);
}

5599
// Describes the bonds of a CustomCompoundBondForce as particle groups: one group per bond,
// with two bonds considered identical when their per-bond parameters match.
class OpenCLCalcCustomCompoundBondForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CustomCompoundBondForce& force) : OpenCLForceInfo(0), force(force) {
    }

    // One particle group per bond.
    int getNumParticleGroups() {
        const int bondCount = force.getNumBonds();
        return bondCount;
    }

    // Fills `particles` with the particles of bond `index`; the bond's parameters are discarded.
    void getParticlesInGroup(int index, vector<int>& particles) {
        vector<double> ignoredParams;
        force.getBondParameters(index, particles, ignoredParams);
    }

    // Returns true when every per-bond parameter of the two bonds compares equal.
    bool areGroupsIdentical(int group1, int group2) {
        vector<int> scratchParticles;
        vector<double> firstParams, secondParams;
        force.getBondParameters(group1, scratchParticles, firstParams);
        force.getBondParameters(group2, scratchParticles, secondParams);
        const int paramCount = (int) firstParams.size();
        int matched = 0;
        // Advance while the parameters agree; stopping early means a mismatch was found.
        while (matched < paramCount && firstParams[matched] == secondParams[matched])
            matched++;
        return matched == paramCount;
    }

private:
    const CustomCompoundBondForce& force;
};

OpenCLCalcCustomCompoundBondForceKernel::~OpenCLCalcCustomCompoundBondForceKernel() {
    // Releases the per-bond parameter set.  Deleting a null pointer is a no-op, so no
    // explicit guard is needed for the case where it was never allocated.
    delete params;
}

void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, const CustomCompoundBondForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    numBonds = endIndex-startIndex;
    if (numBonds == 0)
        return;
    int particlesPerBond = force.getNumParticlesPerBond();
    vector<vector<int> > atoms(numBonds, vector<int>(particlesPerBond));
    params = new OpenCLParameterSet(cl, force.getNumPerBondParameters(), numBonds, "customCompoundBondParams");
    vector<vector<cl_float> > paramVector(numBonds);
    for (int i = 0; i < numBonds; i++) {
        vector<double> parameters;
        force.getBondParameters(startIndex+i, atoms[i], parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
5648
5649
    info = new ForceInfo(force);
    cl.addForce(info);
5650
5651
5652
5653
5654

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
5655
    vector<const TabulatedFunction*> functionList;
5656
    stringstream tableArgs;
peastman's avatar
peastman committed
5657
5658
    tabulatedFunctions.resize(force.getNumTabulatedFunctions());
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
5659
5660
5661
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
5662
        int width;
5663
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
5664
5665
5666
        tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
        tabulatedFunctions[i].upload(f);
        string arrayName = cl.getBondedUtilities().addArgument(tabulatedFunctions[i].getDeviceBuffer(), width == 1 ? "float" : "float"+cl.intToString(width));
5667
5668
5669
5670
5671
5672
5673
5674
5675
5676
5677
5678
5679
        functionDefinitions.push_back(make_pair(name, arrayName));
    }
    
    // Record information about parameters.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    map<string, string> variables;
    for (int i = 0; i < particlesPerBond; i++) {
5680
        string index = cl.intToString(i+1);
5681
5682
5683
5684
5685
5686
5687
5688
5689
        variables["x"+index] = "pos"+index+".x";
        variables["y"+index] = "pos"+index+".y";
        variables["z"+index] = "pos"+index+".z";
    }
    for (int i = 0; i < force.getNumPerBondParameters(); i++) {
        const string& name = force.getPerBondParameterName(i);
        variables[name] = "bondParams"+params->getParameterSuffix(i);
    }
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
5690
5691
5692
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customCompoundBondGlobals", CL_MEM_READ_ONLY);
        globals.upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals.getDeviceBuffer(), "float");
5693
5694
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
5695
            string value = argName+"["+cl.intToString(i)+"]";
5696
5697
5698
5699
5700
5701
5702
5703
5704
5705
5706
5707
5708
5709
5710
            variables[name] = value;
        }
    }

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomCompoundBondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    vector<string> atomNames, posNames;
    for (int i = 0; i < particlesPerBond; i++) {
5711
        string index = cl.intToString(i+1);
5712
5713
5714
5715
5716
        atomNames.push_back("P"+index);
        posNames.push_back("pos"+index);
    }
    stringstream compute;
    int index = 0;
peastman's avatar
peastman committed
5717
5718
    for (auto& distance : distances) {
        const vector<int>& atoms = distance.second;
5719
5720
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        if (computedDeltas.count(deltaName) == 0) {
5721
            compute<<"real4 delta"<<deltaName<<" = ccb_delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5722
5723
            computedDeltas.insert(deltaName);
        }
5724
        compute<<"real r_"<<deltaName<<" = sqrt(delta"<<deltaName<<".w);\n";
peastman's avatar
peastman committed
5725
5726
5727
        variables[distance.first] = "r_"+deltaName;
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(distance.first).optimize();
        index++;
5728
5729
    }
    index = 0;
peastman's avatar
peastman committed
5730
5731
    for (auto& angle : angles) {
        const vector<int>& atoms = angle.second;
5732
5733
5734
5735
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
        if (computedDeltas.count(deltaName1) == 0) {
5736
            compute<<"real4 delta"<<deltaName1<<" = ccb_delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[0]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5737
5738
5739
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
5740
            compute<<"real4 delta"<<deltaName2<<" = ccb_delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[2]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5741
5742
            computedDeltas.insert(deltaName2);
        }
5743
        compute<<"real "<<angleName<<" = ccb_computeAngle(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
peastman's avatar
peastman committed
5744
5745
5746
        variables[angle.first] = angleName;
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(angle.first).optimize();
        index++;
5747
5748
    }
    index = 0;
peastman's avatar
peastman committed
5749
5750
    for (auto& dihedral : dihedrals) {
        const vector<int>& atoms = dihedral.second;
5751
5752
5753
5754
5755
5756
5757
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
        if (computedDeltas.count(deltaName1) == 0) {
5758
            compute<<"real4 delta"<<deltaName1<<" = ccb_delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5759
5760
5761
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
5762
            compute<<"real4 delta"<<deltaName2<<" = ccb_delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[1]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5763
5764
5765
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
5766
            compute<<"real4 delta"<<deltaName3<<" = ccb_delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[3]]<<", "<<force.usesPeriodicBoundaryConditions()<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
5767
5768
            computedDeltas.insert(deltaName3);
        }
5769
5770
5771
        compute<<"real4 "<<crossName1<<" = ccb_computeCross(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
        compute<<"real4 "<<crossName2<<" = ccb_computeCross(delta"<<deltaName2<<", delta"<<deltaName3<<");\n";
        compute<<"real "<<dihedralName<<" = ccb_computeAngle("<<crossName1<<", "<<crossName2<<");\n";
5772
        compute<<dihedralName<<" *= (delta"<<deltaName1<<".x*"<<crossName2<<".x + delta"<<deltaName1<<".y*"<<crossName2<<".y + delta"<<deltaName1<<".z*"<<crossName2<<".z < 0 ? -1 : 1);\n";
peastman's avatar
peastman committed
5773
5774
5775
        variables[dihedral.first] = dihedralName;
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(dihedral.first).optimize();
        index++;
5776
5777
5778
5779
5780
5781
5782
5783
5784
5785
    }

    // Now evaluate the expressions.

    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
    }
    forceExpressions["energy += "] = energyExpression;
5786
5787
5788
5789
5790
5791
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string paramName = force.getEnergyParameterDerivativeName(i);
        string derivVariable = cl.getBondedUtilities().addEnergyParameterDerivative(paramName);
        Lepton::ParsedExpression derivExpression = energyExpression.differentiate(paramName).optimize();
        forceExpressions[derivVariable+" += "] = derivExpression;
    }
5792
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
5793
5794
5795
5796
5797

    // Finally, apply forces to atoms.

    vector<string> forceNames;
    for (int i = 0; i < particlesPerBond; i++) {
5798
        string istr = cl.intToString(i+1);
5799
5800
        string forceName = "force"+istr;
        forceNames.push_back(forceName);
5801
        compute<<"real4 "<<forceName<<" = (real4) 0;\n";
5802
5803
5804
5805
5806
5807
5808
5809
5810
5811
5812
5813
        compute<<"{\n";
        Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x"+istr).optimize();
        Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y"+istr).optimize();
        Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z"+istr).optimize();
        map<string, Lepton::ParsedExpression> expressions;
        if (!isZeroExpression(forceExpressionX))
            expressions[forceName+".x -= "] = forceExpressionX;
        if (!isZeroExpression(forceExpressionY))
            expressions[forceName+".y -= "] = forceExpressionY;
        if (!isZeroExpression(forceExpressionZ))
            expressions[forceName+".z -= "] = forceExpressionZ;
        if (expressions.size() > 0)
5814
            compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
5815
5816
5817
        compute<<"}\n";
    }
    index = 0;
peastman's avatar
peastman committed
5818
5819
    for (auto& distance : distances) {
        const vector<int>& atoms = distance.second;
5820
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
5821
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
5822
5823
        compute<<forceNames[atoms[0]]<<".xyz += "<<"-"<<value<<";\n";
        compute<<forceNames[atoms[1]]<<".xyz += "<<value<<";\n";
peastman's avatar
peastman committed
5824
        index++;
5825
5826
    }
    index = 0;
peastman's avatar
peastman committed
5827
5828
    for (auto& angle : angles) {
        const vector<int>& atoms = angle.second;
5829
5830
5831
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        compute<<"{\n";
5832
5833
5834
5835
5836
        compute<<"real4 crossProd = cross(delta"<<deltaName2<<", delta"<<deltaName1<<");\n";
        compute<<"real lengthCross = max(length(crossProd), (real) 1e-6f);\n";
        compute<<"real4 deltaCross0 = -cross(delta"<<deltaName1<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName1<<".w*lengthCross);\n";
        compute<<"real4 deltaCross2 = cross(delta"<<deltaName2<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName2<<".w*lengthCross);\n";
        compute<<"real4 deltaCross1 = -(deltaCross0+deltaCross2);\n";
5837
5838
5839
5840
        compute<<forceNames[atoms[0]]<<".xyz += deltaCross0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += deltaCross1.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += deltaCross2.xyz;\n";
        compute<<"}\n";
peastman's avatar
peastman committed
5841
        index++;
5842
5843
    }
    index = 0;
peastman's avatar
peastman committed
5844
5845
    for (auto& dihedral : dihedrals) {
        const vector<int>& atoms = dihedral.second;
5846
5847
5848
5849
5850
5851
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        compute<<"{\n";
5852
5853
5854
        compute<<"real r = SQRT(delta"<<deltaName2<<".w);\n";
        compute<<"real4 ff;\n";
        compute<<"ff.x = (-dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName1<<".w;\n";
5855
5856
        compute<<"ff.y = (delta"<<deltaName1<<".x*delta"<<deltaName2<<".x + delta"<<deltaName1<<".y*delta"<<deltaName2<<".y + delta"<<deltaName1<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.z = (delta"<<deltaName3<<".x*delta"<<deltaName2<<".x + delta"<<deltaName3<<".y*delta"<<deltaName2<<".y + delta"<<deltaName3<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
5857
5858
5859
5860
        compute<<"ff.w = (dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName2<<".w;\n";
        compute<<"real4 internalF0 = ff.x*"<<crossName1<<";\n";
        compute<<"real4 internalF3 = ff.w*"<<crossName2<<";\n";
        compute<<"real4 s = ff.y*internalF0 - ff.z*internalF3;\n";
5861
5862
5863
5864
5865
        compute<<forceNames[atoms[0]]<<".xyz += internalF0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += s.xyz-internalF0.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += -s.xyz-internalF3.xyz;\n";
        compute<<forceNames[atoms[3]]<<".xyz += internalF3.xyz;\n";
        compute<<"}\n";
peastman's avatar
peastman committed
5866
        index++;
5867
5868
5869
    }
    cl.getBondedUtilities().addInteraction(atoms, compute.str(), force.getForceGroup());
    map<string, string> replacements;
5870
    replacements["M_PI"] = cl.doubleToString(M_PI);
5871
5872
5873
5874
    cl.getBondedUtilities().addPrefixCode(cl.replaceStrings(OpenCLKernelSources::customCompoundBond, replacements));;
}

/**
 * Refresh the device copy of the global parameters when any of them has changed
 * in the context since the previous call.
 *
 * @param context        the context to read the current global parameter values from
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcCustomCompoundBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    // Without a globals buffer there is nothing to keep in sync.
    if (!globals.isInitialized())
        return 0.0;

    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int p = 0; p < numGlobals; p++) {
        const cl_float latest = (cl_float) context.getParameter(globalParamNames[p]);
        needsUpload = needsUpload || (latest != globalParamValues[p]);
        globalParamValues[p] = latest;
    }
    if (needsUpload)
        globals.upload(globalParamValues);
    return 0.0;
}

5889
5890
5891
5892
5893
5894
/**
 * Re-read the per-bond parameters from the force and upload them for the slice
 * of bonds handled by this context.
 *
 * @param context  the context the parameters are copied to (not used by this method)
 * @param force    the CustomCompoundBondForce to read the parameters from
 * @throws OpenMMException if the number of bonds in this context's slice no longer
 *         matches the stored bond count
 */
void OpenCLCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force) {
    // Recompute this context's slice of the bonds and verify it still has the same size.
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstBond = cl.getContextIndex()*force.getNumBonds()/contextCount;
    const int lastBond = (cl.getContextIndex()+1)*force.getNumBonds()/contextCount;
    if (lastBond-firstBond != numBonds)
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;

    // Record the per-bond parameters.

    vector<vector<cl_float> > bondValues(numBonds);
    vector<int> bondParticles;
    vector<double> bondParameters;
    for (int bond = 0; bond < numBonds; bond++) {
        force.getBondParameters(firstBond+bond, bondParticles, bondParameters);
        vector<cl_float>& row = bondValues[bond];
        const int numValues = (int) bondParameters.size();
        row.resize(bondParameters.size());
        for (int k = 0; k < numValues; k++)
            row[k] = (cl_float) bondParameters[k];
    }
    params->setParameterValues(bondValues);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules(info);
}

5916
/**
 * Describes a CustomManyParticleForce to the OpenCL context: particles are compared
 * by type and per-particle parameters, and each exclusion is reported as a
 * two-particle group.
 */
class OpenCLCalcCustomManyParticleForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const CustomManyParticleForce& force) : OpenCLForceInfo(0), force(force) {
    }

    /** Two particles are identical when they have the same type and equal parameters. */
    bool areParticlesIdentical(int particle1, int particle2) {
        vector<double> firstParams;
        vector<double> secondParams;
        int firstType;
        int secondType;
        force.getParticleParameters(particle1, firstParams, firstType);
        force.getParticleParameters(particle2, secondParams, secondType);
        if (firstType != secondType)
            return false;
        const int paramCount = (int) firstParams.size();
        int k = 0;
        while (k < paramCount) {
            if (firstParams[k] != secondParams[k])
                return false;
            k++;
        }
        return true;
    }

    /** Every exclusion in the force is reported as one particle group. */
    int getNumParticleGroups() {
        return force.getNumExclusions();
    }

    /** Fill `particles` with the two particles of exclusion `index`. */
    void getParticlesInGroup(int index, vector<int>& particles) {
        int first;
        int second;
        force.getExclusionParticles(index, first, second);
        particles.clear();
        particles.push_back(first);
        particles.push_back(second);
    }

    /** All groups are treated as interchangeable; no per-exclusion data is compared. */
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }

private:
    const CustomManyParticleForce& force;
};

/**
 * Release the per-particle parameter set owned by this kernel.
 */
OpenCLCalcCustomManyParticleForceKernel::~OpenCLCalcCustomManyParticleForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is required.
    delete params;
}

void OpenCLCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) {
    if (!cl.getSupports64BitGlobalAtomics())
        throw OpenMMException("CustomManyParticleForce requires a device that supports 64 bit atomic operations");
    int numParticles = force.getNumParticles();
    int particlesPerSet = force.getNumParticlesPerSet();
    bool centralParticleMode = (force.getPermutationMode() == CustomManyParticleForce::UniqueCentralParticle);
    nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
    forceWorkgroupSize = 128;
    findNeighborsWorkgroupSize = (cl.getSIMDWidth() >= 32 ? 128 : 32);
    
    // Record parameter values.
    
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customManyParticleParameters");
    vector<vector<float> > paramVector(numParticles);
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
        int type;
        force.getParticleParameters(i, parameters, type);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (float) parameters[j];
    }
    params->setParameterValues(paramVector);
5977
5978
    info = new ForceInfo(force);
    cl.addForce(info);
5979
5980
5981
5982
5983
5984
5985

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
    vector<const TabulatedFunction*> functionList;
    stringstream tableArgs;
peastman's avatar
peastman committed
5986
    tabulatedFunctions.resize(force.getNumTabulatedFunctions());
5987
5988
5989
5990
5991
5992
5993
5994
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
        string arrayName = "table"+cl.intToString(i);
        functionDefinitions.push_back(make_pair(name, arrayName));
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
        int width;
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
5995
5996
        tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
        tabulatedFunctions[i].upload(f);
5997
5998
5999
6000
6001
6002
6003
6004
6005
6006
6007
6008
6009
6010
6011
6012
6013
6014
6015
6016
6017
6018
6019
6020
6021
6022
6023
6024
6025
        tableArgs << ", __global const float";
        if (width > 1)
            tableArgs << width;
        tableArgs << "* restrict " << arrayName;
    }
    
    // Record information about parameters.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
    }
    vector<pair<ExpressionTreeNode, string> > variables;
    for (int i = 0; i < particlesPerSet; i++) {
        string index = cl.intToString(i+1);
        variables.push_back(makeVariable("x"+index, "pos"+index+".x"));
        variables.push_back(makeVariable("y"+index, "pos"+index+".y"));
        variables.push_back(makeVariable("z"+index, "pos"+index+".z"));
    }
    for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
        const string& name = force.getPerParticleParameterName(i);
        for (int j = 0; j < particlesPerSet; j++) {
            string index = cl.intToString(j+1);
            variables.push_back(makeVariable(name+index, "params"+params->getParameterSuffix(i, index)));
        }
    }
    if (force.getNumGlobalParameters() > 0) {
peastman's avatar
peastman committed
6026
6027
        globals.initialize<cl_float>(cl, force.getNumGlobalParameters(), "customManyParticleGlobals", CL_MEM_READ_ONLY);
        globals.upload(globalParamValues);
6028
6029
6030
6031
6032
6033
6034
6035
6036
6037
6038
6039
6040
6041
6042
6043
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
            string value = "globals["+cl.intToString(i)+"]";
            variables.push_back(makeVariable(name, value));
        }
    }
    
    // Build data structures for type filters.
    
    vector<int> particleTypesVec;
    vector<int> orderIndexVec;
    vector<std::vector<int> > particleOrderVec;
    int numTypes;
    CustomManyParticleForceImpl::buildFilterArrays(force, numTypes, particleTypesVec, orderIndexVec, particleOrderVec);
    bool hasTypeFilters = (particleOrderVec.size() > 1);
    if (hasTypeFilters) {
peastman's avatar
peastman committed
6044
6045
6046
6047
6048
6049
        particleTypes.initialize<int>(cl, particleTypesVec.size(), "customManyParticleTypes");
        orderIndex.initialize<int>(cl, orderIndexVec.size(), "customManyParticleOrderIndex");
        particleOrder.initialize<int>(cl, particleOrderVec.size()*particlesPerSet, "customManyParticleOrder");
        particleTypes.upload(particleTypesVec);
        orderIndex.upload(orderIndexVec);
        vector<int> flattenedOrder(particleOrder.getSize());
6050
6051
6052
        for (int i = 0; i < (int) particleOrderVec.size(); i++)
            for (int j = 0; j < particlesPerSet; j++)
                flattenedOrder[i*particlesPerSet+j] = particleOrderVec[i][j];
peastman's avatar
peastman committed
6053
        particleOrder.upload(flattenedOrder);
6054
6055
6056
6057
6058
6059
6060
6061
6062
6063
6064
6065
6066
6067
6068
6069
6070
6071
6072
6073
    }
    
    // Build data structures for exclusions.
    
    if (force.getNumExclusions() > 0) {
        vector<vector<int> > particleExclusions(numParticles);
        for (int i = 0; i < force.getNumExclusions(); i++) {
            int p1, p2;
            force.getExclusionParticles(i, p1, p2);
            particleExclusions[p1].push_back(p2);
            particleExclusions[p2].push_back(p1);
        }
        vector<int> exclusionsVec;
        vector<int> exclusionStartIndexVec(numParticles+1);
        exclusionStartIndexVec[0] = 0;
        for (int i = 0; i < numParticles; i++) {
            sort(particleExclusions[i].begin(), particleExclusions[i].end());
            exclusionsVec.insert(exclusionsVec.end(), particleExclusions[i].begin(), particleExclusions[i].end());
            exclusionStartIndexVec[i+1] = exclusionsVec.size();
        }
peastman's avatar
peastman committed
6074
6075
6076
6077
        exclusions.initialize<int>(cl, exclusionsVec.size(), "customManyParticleExclusions");
        exclusionStartIndex.initialize<int>(cl, exclusionStartIndexVec.size(), "customManyParticleExclusionStart");
        exclusions.upload(exclusionsVec);
        exclusionStartIndex.upload(exclusionStartIndexVec);
6078
6079
6080
6081
6082
6083
6084
    }
    
    // Build data structures for the neighbor list.
    
    if (nonbondedMethod != NoCutoff) {
        int numAtomBlocks = cl.getNumAtomBlocks();
        int elementSize = (cl.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
peastman's avatar
peastman committed
6085
6086
6087
6088
6089
        blockCenter.initialize(cl, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox.initialize(cl, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        numNeighborPairs.initialize<int>(cl, 1, "customManyParticleNumNeighborPairs");
        neighborStartIndex.initialize<int>(cl, numParticles+1, "customManyParticleNeighborStartIndex");
        numNeighborsForAtom.initialize<int>(cl, numParticles, "customManyParticleNumNeighborsForAtom");
6090
6091
6092
6093
6094

        // Select a size for the array that holds the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        maxNeighborPairs = 150*numParticles;
peastman's avatar
peastman committed
6095
6096
        neighborPairs.initialize<mm_int2>(cl, maxNeighborPairs, "customManyParticleNeighborPairs");
        neighbors.initialize<int>(cl, maxNeighborPairs, "customManyParticleNeighbors");
6097
6098
6099
6100
6101
6102
6103
6104
6105
6106
6107
6108
6109
6110
6111
6112
6113
6114
6115
    }

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomManyParticleForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    vector<string> atomNames, posNames;
    for (int i = 0; i < particlesPerSet; i++) {
        string index = cl.intToString(i+1);
        atomNames.push_back("P"+index);
        posNames.push_back("pos"+index);
    }
    stringstream compute;
    int index = 0;
peastman's avatar
peastman committed
6116
6117
    for (auto& distance : distances) {
        const vector<int>& atoms = distance.second;
6118
6119
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        if (computedDeltas.count(deltaName) == 0) {
6120
            compute<<"real4 delta"<<deltaName<<" = delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
6121
6122
6123
            computedDeltas.insert(deltaName);
        }
        compute<<"real r_"<<deltaName<<" = sqrt(delta"<<deltaName<<".w);\n";
peastman's avatar
peastman committed
6124
6125
6126
        variables.push_back(makeVariable(distance.first, "r_"+deltaName));
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(distance.first).optimize();
        index++;
6127
6128
    }
    index = 0;
peastman's avatar
peastman committed
6129
6130
    for (auto& angle : angles) {
        const vector<int>& atoms = angle.second;
6131
6132
6133
6134
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
        if (computedDeltas.count(deltaName1) == 0) {
6135
            compute<<"real4 delta"<<deltaName1<<" = delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[0]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
6136
6137
6138
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
6139
            compute<<"real4 delta"<<deltaName2<<" = delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[2]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
6140
6141
6142
            computedDeltas.insert(deltaName2);
        }
        compute<<"real "<<angleName<<" = computeAngle(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
peastman's avatar
peastman committed
6143
6144
6145
        variables.push_back(makeVariable(angle.first, angleName));
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(angle.first).optimize();
        index++;
6146
6147
    }
    index = 0;
peastman's avatar
peastman committed
6148
6149
    for (auto& dihedral : dihedrals) {
        const vector<int>& atoms = dihedral.second;
6150
6151
6152
6153
6154
6155
6156
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
        if (computedDeltas.count(deltaName1) == 0) {
6157
            compute<<"real4 delta"<<deltaName1<<" = delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
6158
6159
6160
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
6161
            compute<<"real4 delta"<<deltaName2<<" = delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[1]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
6162
6163
6164
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
6165
            compute<<"real4 delta"<<deltaName3<<" = delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[3]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
6166
6167
6168
6169
6170
6171
            computedDeltas.insert(deltaName3);
        }
        compute<<"real4 "<<crossName1<<" = computeCross(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
        compute<<"real4 "<<crossName2<<" = computeCross(delta"<<deltaName2<<", delta"<<deltaName3<<");\n";
        compute<<"real "<<dihedralName<<" = computeAngle("<<crossName1<<", "<<crossName2<<");\n";
        compute<<dihedralName<<" *= (delta"<<deltaName1<<".x*"<<crossName2<<".x + delta"<<deltaName1<<".y*"<<crossName2<<".y + delta"<<deltaName1<<".z*"<<crossName2<<".z < 0 ? -1 : 1);\n";
peastman's avatar
peastman committed
6172
6173
6174
        variables.push_back(makeVariable(dihedral.first, dihedralName));
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(dihedral.first).optimize();
        index++;
6175
6176
6177
6178
6179
6180
6181
6182
6183
6184
6185
6186
6187
6188
6189
6190
6191
6192
6193
6194
6195
6196
6197
6198
6199
6200
6201
6202
6203
6204
6205
6206
6207
6208
6209
    }

    // Now evaluate the expressions.

    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        compute<<buffer.getType()<<" params"<<(i+1)<<" = global_params"<<(i+1)<<"[index];\n";
    }
    forceExpressions["energy += "] = energyExpression;
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");

    // Apply forces to atoms.

    vector<string> forceNames;
    for (int i = 0; i < particlesPerSet; i++) {
        string istr = cl.intToString(i+1);
        string forceName = "force"+istr;
        forceNames.push_back(forceName);
        compute<<"real4 "<<forceName<<" = (real4) 0;\n";
        compute<<"{\n";
        Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x"+istr).optimize();
        Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y"+istr).optimize();
        Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z"+istr).optimize();
        map<string, Lepton::ParsedExpression> expressions;
        if (!isZeroExpression(forceExpressionX))
            expressions[forceName+".x -= "] = forceExpressionX;
        if (!isZeroExpression(forceExpressionY))
            expressions[forceName+".y -= "] = forceExpressionY;
        if (!isZeroExpression(forceExpressionZ))
            expressions[forceName+".z -= "] = forceExpressionZ;
        if (expressions.size() > 0)
            compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
        compute<<"}\n";
    }
    index = 0;
peastman's avatar
peastman committed
6210
6211
    for (auto& distance : distances) {
        const vector<int>& atoms = distance.second;
6212
6213
6214
6215
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
        compute<<forceNames[atoms[0]]<<".xyz += "<<"-"<<value<<";\n";
        compute<<forceNames[atoms[1]]<<".xyz += "<<value<<";\n";
peastman's avatar
peastman committed
6216
        index++;
6217
6218
    }
    index = 0;
peastman's avatar
peastman committed
6219
6220
    for (auto& angle : angles) {
        const vector<int>& atoms = angle.second;
6221
6222
6223
6224
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        compute<<"{\n";
        compute<<"real4 crossProd = cross(delta"<<deltaName2<<", delta"<<deltaName1<<");\n";
6225
        compute<<"real lengthCross = max(SQRT(dot(crossProd, crossProd)), (real) 1e-6f);\n";
6226
6227
6228
6229
6230
6231
6232
        compute<<"real4 deltaCross0 = -cross(delta"<<deltaName1<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName1<<".w*lengthCross);\n";
        compute<<"real4 deltaCross2 = cross(delta"<<deltaName2<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName2<<".w*lengthCross);\n";
        compute<<"real4 deltaCross1 = -(deltaCross0+deltaCross2);\n";
        compute<<forceNames[atoms[0]]<<".xyz += deltaCross0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += deltaCross1.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += deltaCross2.xyz;\n";
        compute<<"}\n";
peastman's avatar
peastman committed
6233
        index++;
6234
6235
    }
    index = 0;
peastman's avatar
peastman committed
6236
6237
    for (auto& dihedral : dihedrals) {
        const vector<int>& atoms = dihedral.second;
6238
6239
6240
6241
6242
6243
6244
6245
6246
6247
6248
6249
6250
6251
6252
6253
6254
6255
6256
6257
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        compute<<"{\n";
        compute<<"real r = sqrt(delta"<<deltaName2<<".w);\n";
        compute<<"real4 ff;\n";
        compute<<"ff.x = (-dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName1<<".w;\n";
        compute<<"ff.y = (delta"<<deltaName1<<".x*delta"<<deltaName2<<".x + delta"<<deltaName1<<".y*delta"<<deltaName2<<".y + delta"<<deltaName1<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.z = (delta"<<deltaName3<<".x*delta"<<deltaName2<<".x + delta"<<deltaName3<<".y*delta"<<deltaName2<<".y + delta"<<deltaName3<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.w = (dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName2<<".w;\n";
        compute<<"real4 internalF0 = ff.x*"<<crossName1<<";\n";
        compute<<"real4 internalF3 = ff.w*"<<crossName2<<";\n";
        compute<<"real4 s = ff.y*internalF0 - ff.z*internalF3;\n";
        compute<<forceNames[atoms[0]]<<".xyz += internalF0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += s.xyz-internalF0.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += -s.xyz-internalF3.xyz;\n";
        compute<<forceNames[atoms[3]]<<".xyz += internalF3.xyz;\n";
        compute<<"}\n";
peastman's avatar
peastman committed
6258
        index++;
6259
6260
6261
6262
6263
6264
6265
6266
6267
6268
6269
6270
6271
6272
6273
6274
6275
6276
6277
6278
6279
    }
    
    // Store forces to global memory.
    
    for (int i = 0; i < particlesPerSet; i++)
        compute<<"storeForce(atom"<<(i+1)<<", "<<forceNames[i]<<", forceBuffers);\n";
    
    // Create other replacements that depend on the number of particles per set.
    
    stringstream numCombinations, atomsForCombination, isValidCombination, permute, loadData, verifyCutoff, verifyExclusions;
    if (hasTypeFilters) {
        permute<<"int particleSet[] = {";
        for (int i = 0; i < particlesPerSet; i++) {
            permute<<"p"<<(i+1);
            if (i < particlesPerSet-1)
                permute<<", ";
        }
        permute<<"};\n";
    }
    for (int i = 0; i < particlesPerSet; i++) {
        if (hasTypeFilters)
peastman's avatar
Bug fix  
peastman committed
6280
            permute<<"int atom"<<(i+1)<<" = particleSet[particleOrder["<<particlesPerSet<<"*order+"<<i<<"]];\n";
6281
6282
6283
6284
6285
6286
6287
6288
6289
6290
6291
6292
6293
6294
6295
6296
6297
6298
6299
6300
6301
6302
6303
6304
6305
6306
6307
6308
6309
6310
6311
6312
6313
6314
6315
6316
6317
6318
6319
6320
6321
6322
6323
6324
6325
6326
6327
6328
6329
6330
6331
6332
6333
6334
6335
6336
6337
6338
        else
            permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n";
        loadData<<"real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n";
        for (int j = 0; j < (int) params->getBuffers().size(); j++)
            loadData<<params->getBuffers()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n";
    }
    if (centralParticleMode) {
        for (int i = 1; i < particlesPerSet; i++) {
            if (i > 1)
                isValidCombination<<" && p"<<(i+1)<<">p"<<i<<" && ";
            isValidCombination<<"p"<<(i+1)<<"!=p1";
        }
    }
    else {
        for (int i = 2; i < particlesPerSet; i++) {
            if (i > 2)
                isValidCombination<<" && ";
            isValidCombination<<"a"<<(i+1)<<">a"<<i;
        }
    }
    atomsForCombination<<"int tempIndex = index;\n";
    for (int i = 1; i < particlesPerSet; i++) {
        if (i > 1)
            numCombinations<<"*";
        numCombinations<<"numNeighbors";
        if (centralParticleMode)
            atomsForCombination<<"int a"<<(i+1)<<" = tempIndex%numNeighbors;\n";
        else
            atomsForCombination<<"int a"<<(i+1)<<" = 1+tempIndex%numNeighbors;\n";
        if (i < particlesPerSet-1)
            atomsForCombination<<"tempIndex /= numNeighbors;\n";
    }
    if (particlesPerSet > 2) {
        if (centralParticleMode)
            atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2-1);\n";
        else
            atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2+1);\n";
    }
    for (int i = 1; i < particlesPerSet; i++) {
        if (nonbondedMethod == NoCutoff) {
            if (centralParticleMode)
                atomsForCombination<<"int p"<<(i+1)<<" = a"<<(i+1)<<";\n";
            else
                atomsForCombination<<"int p"<<(i+1)<<" = p1+a"<<(i+1)<<";\n";
        }
        else {
            if (centralParticleMode)
                atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor+a"<<(i+1)<<"];\n";
            else
                atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor-1+a"<<(i+1)<<"];\n";
        }
    }
    if (nonbondedMethod != NoCutoff) {
        for (int i = 1; i < particlesPerSet; i++)
            verifyCutoff<<"real4 pos"<<(i+1)<<" = posq[p"<<(i+1)<<"];\n";
        if (!centralParticleMode) {
            for (int i = 1; i < particlesPerSet; i++) {
                for (int j = i+1; j < particlesPerSet; j++)
6339
                    verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ).w < CUTOFF_SQUARED);\n";
6340
6341
6342
6343
6344
6345
6346
6347
6348
6349
6350
6351
6352
6353
6354
6355
6356
6357
6358
6359
6360
6361
6362
6363
6364
6365
6366
6367
6368
6369
6370
6371
6372
6373
6374
6375
6376
6377
6378
6379
6380
6381
6382
6383
6384
6385
6386
6387
6388
6389
6390
6391
6392
6393
6394
6395
6396
6397
6398
6399
6400
6401
6402
6403
6404
6405
6406
6407
6408
            }
        }
    }
    if (force.getNumExclusions() > 0) {
        int startCheckFrom = (nonbondedMethod == NoCutoff ? 0 : 1);
        for (int i = startCheckFrom; i < particlesPerSet; i++)
            for (int j = i+1; j < particlesPerSet; j++)
                verifyExclusions<<"includeInteraction &= !isInteractionExcluded(p"<<(i+1)<<", p"<<(j+1)<<", exclusions, exclusionStartIndex);\n";
    }
    string computeTypeIndex = "particleTypes[p"+cl.intToString(particlesPerSet)+"]";
    for (int i = particlesPerSet-2; i >= 0; i--)
        computeTypeIndex = "particleTypes[p"+cl.intToString(i+1)+"]+"+cl.intToString(numTypes)+"*("+computeTypeIndex+")";
    
    // Create replacements for extra arguments.
    
    stringstream extraArgs;
    if (force.getNumGlobalParameters() > 0)
        extraArgs << ", __global const float* globals";
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        extraArgs<<", __global const "<<buffer.getType()<<"* restrict global_params"<<(i+1);
    }

    // Create the kernels.

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = compute.str();
    replacements["NUM_CANDIDATE_COMBINATIONS"] = numCombinations.str();
    replacements["FIND_ATOMS_FOR_COMBINATION_INDEX"] = atomsForCombination.str();
    replacements["IS_VALID_COMBINATION"] = isValidCombination.str();
    replacements["VERIFY_CUTOFF"] = verifyCutoff.str();
    replacements["VERIFY_EXCLUSIONS"] = verifyExclusions.str();
    replacements["PERMUTE_ATOMS"] = permute.str();
    replacements["LOAD_PARTICLE_DATA"] = loadData.str();
    replacements["COMPUTE_TYPE_INDEX"] = computeTypeIndex;
    replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
    map<string, string> defines;
    if (nonbondedMethod != NoCutoff)
        defines["USE_CUTOFF"] = "1";
    if (nonbondedMethod == CutoffPeriodic)
        defines["USE_PERIODIC"] = "1";
    if (centralParticleMode)
        defines["USE_CENTRAL_PARTICLE"] = "1";
    if (hasTypeFilters)
        defines["USE_FILTERS"] = "1";
    if (force.getNumExclusions() > 0)
        defines["USE_EXCLUSIONS"] = "1";
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    defines["M_PI"] = cl.doubleToString(M_PI);
    defines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
    defines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
    defines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
    defines["FIND_NEIGHBORS_WORKGROUP_SIZE"] = cl.intToString(findNeighborsWorkgroupSize);
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customManyParticle, replacements), defines);
    forceKernel = cl::Kernel(program, "computeInteraction");
    blockBoundsKernel = cl::Kernel(program, "findBlockBounds");
    neighborsKernel = cl::Kernel(program, "findNeighbors");
    startIndicesKernel = cl::Kernel(program, "computeNeighborStartIndices");
    copyPairsKernel = cl::Kernel(program, "copyPairsToNeighborList");
}

/**
 * Launch the kernels that evaluate the custom many-particle interaction.
 *
 * On the first call the arguments of every kernel are bound.  On every call the
 * global parameter values are refreshed, the neighbor list is rebuilt (only when
 * a cutoff is in use) and the force kernel is executed.  If the neighbor list
 * buffers turn out to be too small, they are enlarged and the whole sequence is
 * run again.
 *
 * @param context        the context from which current global parameter values are read
 * @param includeForces  not consulted by this function
 * @param includeEnergy  not consulted by this function
 * @return always 0.0; the context's energy buffer is handed to the force kernel as an argument
 */
double OpenCLCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;
        
        // Set arguments for the force kernel.  The order below fixes the slot numbers
        // that are referenced again further down when buffers are resized: slots 0-2 are
        // the force, energy and position buffers, the periodic box arguments follow
        // (index is advanced by 5 for them), so "neighbors" lands in slot 8.
        
        int index = 0;
        forceKernel.setArg<cl::Buffer>(index++, cl.getLongForceBuffer().getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        setPeriodicBoxArgs(cl, forceKernel, index);
        index += 5;
        if (nonbondedMethod != NoCutoff) {
            forceKernel.setArg<cl::Buffer>(index++, neighbors.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, neighborStartIndex.getDeviceBuffer());
        }
        if (particleTypes.isInitialized()) {
            forceKernel.setArg<cl::Buffer>(index++, particleTypes.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, orderIndex.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, particleOrder.getDeviceBuffer());
        }
        if (exclusions.isInitialized()) {
            forceKernel.setArg<cl::Buffer>(index++, exclusions.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, exclusionStartIndex.getDeviceBuffer());
        }
        if (globals.isInitialized())
            forceKernel.setArg<cl::Buffer>(index++, globals.getDeviceBuffer());
        for (auto& buffer : params->getBuffers())
            forceKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        for (auto& function : tabulatedFunctions)
            forceKernel.setArg<cl::Buffer>(index++, function.getDeviceBuffer());
        
        if (nonbondedMethod != NoCutoff) {
            // Set arguments for the block bounds kernel.

            index = 0;
            setPeriodicBoxArgs(cl, blockBoundsKernel, index);
            index += 5;
            blockBoundsKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            blockBoundsKernel.setArg<cl::Buffer>(index++, blockCenter.getDeviceBuffer());
            blockBoundsKernel.setArg<cl::Buffer>(index++, blockBoundingBox.getDeviceBuffer());
            blockBoundsKernel.setArg<cl::Buffer>(index++, numNeighborPairs.getDeviceBuffer());

            // Set arguments for the neighbor list kernel.  "neighborPairs" ends up in
            // slot 8, and slot 11 is deliberately skipped here: it receives
            // maxNeighborPairs on every execution, because that value can change.

            index = 0;
            setPeriodicBoxArgs(cl, neighborsKernel, index);
            index += 5;
            neighborsKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, blockCenter.getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, blockBoundingBox.getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, neighborPairs.getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, numNeighborPairs.getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, numNeighborsForAtom.getDeviceBuffer());
            index++;
            if (exclusions.isInitialized()) {
                neighborsKernel.setArg<cl::Buffer>(index++, exclusions.getDeviceBuffer());
                neighborsKernel.setArg<cl::Buffer>(index++, exclusionStartIndex.getDeviceBuffer());
            }
            
            // Set arguments for the kernel to find neighbor list start indices.
            // Slot 3 (maxNeighborPairs) is set on every execution.
            
            index = 0;
            startIndicesKernel.setArg<cl::Buffer>(index++, numNeighborsForAtom.getDeviceBuffer());
            startIndicesKernel.setArg<cl::Buffer>(index++, neighborStartIndex.getDeviceBuffer());
            startIndicesKernel.setArg<cl::Buffer>(index++, numNeighborPairs.getDeviceBuffer());

            // Set arguments for the kernel to assemble the final neighbor list.
            // Slot 3 is skipped for the same reason as above.
            
            index = 0;
            copyPairsKernel.setArg<cl::Buffer>(index++, neighborPairs.getDeviceBuffer());
            copyPairsKernel.setArg<cl::Buffer>(index++, neighbors.getDeviceBuffer());
            copyPairsKernel.setArg<cl::Buffer>(index++, numNeighborPairs.getDeviceBuffer());
            index++;
            copyPairsKernel.setArg<cl::Buffer>(index++, numNeighborsForAtom.getDeviceBuffer());
            copyPairsKernel.setArg<cl::Buffer>(index++, neighborStartIndex.getDeviceBuffer());
        }
    }
    if (globals.isInitialized()) {
        // Re-read every global parameter, and upload the whole array only if at
        // least one of the single precision values differs from the cached copy.

        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals.upload(globalParamValues);
    }
    while (true) {
        int* numPairs = (int*) cl.getPinnedBuffer();
        cl::Event event;
        if (nonbondedMethod != NoCutoff) {
            // These are the slots that were skipped during the one-time setup.

            neighborsKernel.setArg<int>(11, maxNeighborPairs);
            startIndicesKernel.setArg<int>(3, maxNeighborPairs);
            copyPairsKernel.setArg<int>(3, maxNeighborPairs);
            cl.executeKernel(blockBoundsKernel, cl.getNumAtomBlocks());
            cl.executeKernel(neighborsKernel, cl.getNumAtoms(), findNeighborsWorkgroupSize);

            // We need to make sure there was enough memory for the neighbor list.  Download the
            // information asynchronously so kernels can be running at the same time.

            numNeighborPairs.download(numPairs, false);
            cl.getQueue().enqueueMarker(&event);
            cl.executeKernel(startIndicesKernel, 256, 256);
            cl.executeKernel(copyPairsKernel, maxNeighborPairs);
        }
        int maxThreads = min(cl.getNumAtoms()*forceWorkgroupSize, cl.getEnergyBuffer().getSize());
        cl.executeKernel(forceKernel, maxThreads, forceWorkgroupSize);
        if (nonbondedMethod != NoCutoff) {
            // Make sure there was enough memory for the neighbor list.

            event.wait();
            if (*numPairs > maxNeighborPairs) {
                // Resize the arrays and run the calculation again.  The resized arrays
                // have to be bound again in every kernel slot that refers to them.

                maxNeighborPairs = (int) (1.1*(*numPairs));
                neighborPairs.resize(maxNeighborPairs);
                neighbors.resize(maxNeighborPairs);
                forceKernel.setArg<cl::Buffer>(8, neighbors.getDeviceBuffer());
                neighborsKernel.setArg<cl::Buffer>(8, neighborPairs.getDeviceBuffer());
                copyPairsKernel.setArg<cl::Buffer>(0, neighborPairs.getDeviceBuffer());
                copyPairsKernel.setArg<cl::Buffer>(1, neighbors.getDeviceBuffer());
                continue;
            }
        }
        break;
    }
    return 0.0;
}

/**
 * Copy the per-particle parameter values of a CustomManyParticleForce into the
 * parameter set used by this kernel.
 *
 * @param context  not consulted by this function
 * @param force    the force whose per-particle parameters are read
 * @throws OpenMMException if the force's particle count differs from the context's atom count
 */
void OpenCLCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& context, const CustomManyParticleForce& force) {
    int numParticles = force.getNumParticles();
    if (numParticles != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");
    
    // Record the per-particle parameters, converted to single precision.  The particle
    // type returned by getParticleParameters() is read but not used here.
    
    vector<vector<float> > paramVector(numParticles);
    vector<double> parameters;
    int type;
    for (int i = 0; i < numParticles; i++) {
        force.getParticleParameters(i, parameters, type);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (float) parameters[j];
    }
    params->setParameterValues(paramVector);
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules(info);
}

6556
/**
 * Describes the particles and particle groups of a GayBerneForce through the
 * OpenCLForceInfo interface.  Groups are numbered with all exceptions first
 * (indices 0 .. getNumExceptions()-1), followed by one group per particle.
 */
class OpenCLCalcGayBerneForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(int requiredBuffers, const GayBerneForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    /**
     * Two particles are identical when sigma, epsilon, the three radii and the three
     * scale factors all match exactly.  The axis particle indices are not compared.
     */
    bool areParticlesIdentical(int particle1, int particle2) {
        int xparticle1, yparticle1;
        double sigma1, epsilon1, sx1, sy1, sz1, ex1, ey1, ez1;
        int xparticle2, yparticle2;
        double sigma2, epsilon2, sx2, sy2, sz2, ex2, ey2, ez2;
        force.getParticleParameters(particle1, sigma1, epsilon1, xparticle1, yparticle1, sx1, sy1, sz1, ex1, ey1, ez1);
        force.getParticleParameters(particle2, sigma2, epsilon2, xparticle2, yparticle2, sx2, sy2, sz2, ex2, ey2, ez2);
        return (sigma1 == sigma2 && epsilon1 == epsilon2 && sx1 == sx2 && sy1 == sy2 && sz1 == sz2 && ex1 == ex2 && ey1 == ey2 && ez1 == ez2);
    }
    /**
     * One group per exception plus one group per particle.
     */
    int getNumParticleGroups() {
        return force.getNumExceptions()+force.getNumParticles();
    }
    /**
     * Fill "particles" with the members of a group: the two particles of an exception,
     * or a particle together with whichever of its axis particles are set (index > -1).
     */
    void getParticlesInGroup(int index, vector<int>& particles) {
        if (index < force.getNumExceptions()) {
            int particle1, particle2;
            double sigma, epsilon;
            force.getExceptionParameters(index, particle1, particle2, sigma, epsilon);
            particles.resize(2);
            particles[0] = particle1;
            particles[1] = particle2;
        }
        else {
            int particle = index-force.getNumExceptions();
            int xparticle, yparticle;
            double sigma, epsilon, sx, sy, sz, ex, ey, ez;
            force.getParticleParameters(particle, sigma, epsilon, xparticle, yparticle, sx, sy, sz, ex, ey, ez);
            particles.clear();
            particles.push_back(particle);
            if (xparticle > -1)
                particles.push_back(xparticle);
            if (yparticle > -1)
                particles.push_back(yparticle);
        }
    }
    /**
     * Two exception groups are identical when their sigma and epsilon match exactly.
     * Every other combination of groups is reported as identical.
     */
    bool areGroupsIdentical(int group1, int group2) {
        if (group1 < force.getNumExceptions() && group2 < force.getNumExceptions()) {
            int particle1, particle2;
            double sigma1, sigma2, epsilon1, epsilon2;
            force.getExceptionParameters(group1, particle1, particle2, sigma1, epsilon1);
            force.getExceptionParameters(group2, particle1, particle2, sigma2, epsilon2);
            return (sigma1 == sigma2 && epsilon1 == epsilon2);
        }
        return true;
    }
private:
    const GayBerneForce& force;
};

/**
 * Listener that forwards its execute() call to the owning kernel's sortAtoms().
 * An instance is handed to the context with addReorderListener() during initialize().
 */
class OpenCLCalcGayBerneForceKernel::ReorderListener : public OpenCLContext::ReorderListener {
private:
    // The kernel whose atoms get sorted again.
    OpenCLCalcGayBerneForceKernel& owner;
public:
    ReorderListener(OpenCLCalcGayBerneForceKernel& kernel) : owner(kernel) {
    }
    void execute() {
        owner.sortAtoms();
    }
};

void OpenCLCalcGayBerneForceKernel::initialize(const System& system, const GayBerneForce& force) {
6620
6621
    if (!cl.getSupports64BitGlobalAtomics())
        throw OpenMMException("GayBerneForce requires a device that supports 64 bit atomic operations");
6622
6623
6624
6625

    // Initialize interactions.

    int numParticles = force.getNumParticles();
peastman's avatar
peastman committed
6626
6627
6628
6629
6630
6631
6632
6633
    sigParams.initialize<mm_float4>(cl, cl.getPaddedNumAtoms(), "sigParams");
    epsParams.initialize<mm_float2>(cl, cl.getPaddedNumAtoms(), "epsParams");
    scale.initialize<mm_float4>(cl, cl.getPaddedNumAtoms(), "scale");
    axisParticleIndices.initialize<mm_int2>(cl, cl.getPaddedNumAtoms(), "axisParticleIndices");
    sortedParticles.initialize<cl_int>(cl, cl.getPaddedNumAtoms(), "sortedParticles");
    aMatrix.initialize<cl_float>(cl, 9*cl.getPaddedNumAtoms(), "aMatrix");
    bMatrix.initialize<cl_float>(cl, 9*cl.getPaddedNumAtoms(), "bMatrix");
    gMatrix.initialize<cl_float>(cl, 9*cl.getPaddedNumAtoms(), "gMatrix");
6634
6635
6636
6637
6638
6639
6640
6641
6642
6643
6644
6645
6646
    vector<mm_float4> sigParamsVector(cl.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
    vector<mm_float2> epsParamsVector(cl.getPaddedNumAtoms(), mm_float2(0, 0));
    vector<mm_float4> scaleVector(cl.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
    vector<mm_int2> axisParticleVector(cl.getPaddedNumAtoms(), mm_int2(0, 0));
    isRealParticle.resize(cl.getPaddedNumAtoms());
    for (int i = 0; i < numParticles; i++) {
        int xparticle, yparticle;
        double sigma, epsilon, sx, sy, sz, ex, ey, ez;
        force.getParticleParameters(i, sigma, epsilon, xparticle, yparticle, sx, sy, sz, ex, ey, ez);
        axisParticleVector[i] = mm_int2(xparticle, yparticle);
        sigParamsVector[i] = mm_float4((float) (0.5*sigma), (float) (0.25*sx*sx), (float) (0.25*sy*sy), (float) (0.25*sz*sz));
        epsParamsVector[i] = mm_float2((float) sqrt(epsilon), (float) (0.125*(sx*sy + sz*sz)*sqrt(sx*sy)));
        scaleVector[i] = mm_float4((float) (1/sqrt(ex)), (float) (1/sqrt(ey)), (float) (1/sqrt(ez)), 0);
6647
        isRealParticle[i] = (epsilon != 0.0);
6648
    }
peastman's avatar
peastman committed
6649
6650
6651
6652
    sigParams.upload(sigParamsVector);
    epsParams.upload(epsParamsVector);
    scale.upload(scaleVector);
    axisParticleIndices.upload(axisParticleVector);
6653
    
6654
6655
6656
6657
6658
6659
6660
6661
6662
6663
6664
6665
6666
6667
6668
6669
6670
6671
6672
6673
6674
    // Record exceptions and exclusions.

    vector<mm_float2> exceptionParamsVec;
    for (int i = 0; i < force.getNumExceptions(); i++) {
        int particle1, particle2;
        double sigma, epsilon;
        force.getExceptionParameters(i, particle1, particle2, sigma, epsilon);
        if (epsilon != 0.0) {
            exceptionParamsVec.push_back(mm_float2((float) sigma, (float) epsilon));
            exceptionAtoms.push_back(make_pair(particle1, particle2));
            isRealParticle[particle1] = true;
            isRealParticle[particle2] = true;
        }
        if (isRealParticle[particle1] && isRealParticle[particle2])
            excludedPairs.push_back(pair<int, int>(particle1, particle2));
    }
    numRealParticles = 0;
    for (int i = 0; i < isRealParticle.size(); i++)
        if (isRealParticle[i])
            numRealParticles++;
    int numExceptions = exceptionParamsVec.size();
peastman's avatar
peastman committed
6675
6676
6677
6678
    exclusions.initialize<cl_int>(cl, max(1, (int) excludedPairs.size()), "exclusions");
    exclusionStartIndex.initialize<cl_int>(cl, numRealParticles+1, "exclusionStartIndex");
    exceptionParticles.initialize<mm_int4>(cl, max(1, numExceptions), "exceptionParticles");
    exceptionParams.initialize<mm_float2>(cl, max(1, numExceptions), "exceptionParams");
6679
    if (numExceptions > 0)
peastman's avatar
peastman committed
6680
        exceptionParams.upload(exceptionParamsVec);
6681
    
6682
6683
6684
6685
    // Create data structures used for the neighbor list.

    int numAtomBlocks = (numRealParticles+31)/32;
    int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
peastman's avatar
peastman committed
6686
6687
6688
    blockCenter.initialize(cl, numAtomBlocks, 4*elementSize, "blockCenter");
    blockBoundingBox.initialize(cl, numAtomBlocks, 4*elementSize, "blockBoundingBox");
    sortedPos.initialize(cl, numRealParticles, 4*elementSize, "sortedPos");
6689
    maxNeighborBlocks = numRealParticles*2;
peastman's avatar
peastman committed
6690
6691
6692
    neighbors.initialize<cl_int>(cl, maxNeighborBlocks*32, "neighbors");
    neighborIndex.initialize<cl_int>(cl, maxNeighborBlocks, "neighborIndex");
    neighborBlockCount.initialize<cl_int>(cl, 1, "neighborBlockCount");
6693

6694
    // Create array for accumulating torques.
6695
    
peastman's avatar
peastman committed
6696
6697
    torque.initialize<cl_long>(cl, 3*cl.getPaddedNumAtoms(), "torque");
    cl.addAutoclearBuffer(torque);
6698
6699
6700
6701
6702
6703
6704
6705

    // Create the kernels.
    
    nonbondedMethod = force.getNonbondedMethod();
    bool useCutoff = (nonbondedMethod != GayBerneForce::NoCutoff);
    bool usePeriodic = (nonbondedMethod == GayBerneForce::CutoffPeriodic);
    map<string, string> defines;
    defines["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
6706
6707
    double cutoff = force.getCutoffDistance();
    defines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
6708
    if (useCutoff) {
6709
6710
6711
6712
        defines["USE_CUTOFF"] = 1;
        if (usePeriodic)
            defines["USE_PERIODIC"] = "1";
        
6713
6714
6715
6716
        // Compute the switching coefficients.
        
        if (force.getUseSwitchingFunction()) {
            defines["SWITCH_CUTOFF"] = cl.doubleToString(force.getSwitchingDistance());
6717
6718
6719
            defines["SWITCH_C3"] = cl.doubleToString(10/pow(force.getSwitchingDistance()-cutoff, 3.0));
            defines["SWITCH_C4"] = cl.doubleToString(15/pow(force.getSwitchingDistance()-cutoff, 4.0));
            defines["SWITCH_C5"] = cl.doubleToString(6/pow(force.getSwitchingDistance()-cutoff, 5.0));
6720
6721
        }
    }
6722
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
6723
6724
6725
6726
6727
    cl::Program program = cl.createProgram(OpenCLKernelSources::gayBerne, defines);
    framesKernel = cl::Kernel(program, "computeEllipsoidFrames");
    blockBoundsKernel = cl::Kernel(program, "findBlockBounds");
    neighborsKernel = cl::Kernel(program, "findNeighbors");
    forceKernel = cl::Kernel(program, "computeForce");
6728
    torqueKernel = cl::Kernel(program, "applyTorques");
6729
6730
    info = new ForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force);
    cl.addForce(info);
6731
6732
6733
6734
6735
6736
6737
6738
6739
    cl.addReorderListener(new ReorderListener(*this));
}

// Runs the Gay-Berne kernels for one evaluation.  The first call sorts the atoms and binds
// every kernel argument that stays fixed between calls.  Arguments that can change (the
// periodic box vectors and, after the neighbor list is enlarged, the neighbor buffers and
// their capacity) are rebound as needed.  Always returns 0.0: the energy buffer is handed
// to the force kernel instead of being read back here.  The context, includeForces and
// includeEnergy arguments are not consulted by this method.
double OpenCLCalcGayBerneForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        sortAtoms();

        // Arguments for the kernel that computes the ellipsoid frames.

        framesKernel.setArg<cl_int>(0, numRealParticles);
        framesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(2, axisParticleIndices.getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(3, sigParams.getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(4, scale.getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(5, aMatrix.getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(6, bMatrix.getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(7, gMatrix.getDeviceBuffer());
        framesKernel.setArg<cl::Buffer>(8, sortedParticles.getDeviceBuffer());

        // Arguments for the block bounds kernel.  Arguments 1-5 are the periodic box
        // arguments, which are set on every call below.

        blockBoundsKernel.setArg<cl_int>(0, numRealParticles);
        blockBoundsKernel.setArg<cl::Buffer>(6, sortedParticles.getDeviceBuffer());
        blockBoundsKernel.setArg<cl::Buffer>(7, cl.getPosq().getDeviceBuffer());
        blockBoundsKernel.setArg<cl::Buffer>(8, sortedPos.getDeviceBuffer());
        blockBoundsKernel.setArg<cl::Buffer>(9, blockCenter.getDeviceBuffer());
        blockBoundsKernel.setArg<cl::Buffer>(10, blockBoundingBox.getDeviceBuffer());
        blockBoundsKernel.setArg<cl::Buffer>(11, neighborBlockCount.getDeviceBuffer());

        // Arguments for the neighbor list kernel.  Arguments 2-6 are the periodic box
        // arguments, which are set on every call below.

        neighborsKernel.setArg<cl_int>(0, numRealParticles);
        neighborsKernel.setArg<cl_int>(1, maxNeighborBlocks);
        neighborsKernel.setArg<cl::Buffer>(7, sortedPos.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(8, blockCenter.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(9, blockBoundingBox.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(10, neighbors.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(11, neighborIndex.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(12, neighborBlockCount.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(13, exclusions.getDeviceBuffer());
        neighborsKernel.setArg<cl::Buffer>(14, exclusionStartIndex.getDeviceBuffer());

        // Arguments for the force kernel.  The first 16 are always present; the neighbor
        // list arguments (16-19) are only bound when a cutoff is in use.

        int index = 0;
        forceKernel.setArg<cl::Buffer>(index++, cl.getLongForceBuffer().getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, torque.getDeviceBuffer());
        forceKernel.setArg<cl_int>(index++, numRealParticles);
        forceKernel.setArg<cl_int>(index++, exceptionAtoms.size());
        forceKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, sortedPos.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, sigParams.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, epsParams.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, sortedParticles.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, aMatrix.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, bMatrix.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, gMatrix.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, exclusions.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, exclusionStartIndex.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, exceptionParticles.getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, exceptionParams.getDeviceBuffer());
        if (nonbondedMethod != GayBerneForce::NoCutoff) {
            forceKernel.setArg<cl_int>(index++, maxNeighborBlocks);
            forceKernel.setArg<cl::Buffer>(index++, neighbors.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, neighborIndex.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, neighborBlockCount.getDeviceBuffer());
        }

        // Arguments for the kernel that applies the torques.

        index = 0;
        torqueKernel.setArg<cl::Buffer>(index++, cl.getLongForceBuffer().getDeviceBuffer());
        torqueKernel.setArg<cl::Buffer>(index++, torque.getDeviceBuffer());
        torqueKernel.setArg<cl_int>(index++, numRealParticles);
        torqueKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        torqueKernel.setArg<cl::Buffer>(index++, axisParticleIndices.getDeviceBuffer());
        torqueKernel.setArg<cl::Buffer>(index++, sortedParticles.getDeviceBuffer());
    }
    cl.executeKernel(framesKernel, numRealParticles);
    setPeriodicBoxArgs(cl, blockBoundsKernel, 1);
    cl.executeKernel(blockBoundsKernel, (numRealParticles+31)/32);
    if (nonbondedMethod == GayBerneForce::NoCutoff) {
        cl.executeKernel(forceKernel, cl.getNonbondedUtilities().getNumForceThreadBlocks()*cl.getNonbondedUtilities().getForceThreadBlockSize());
    }
    else {
        while (true) {
            setPeriodicBoxArgs(cl, neighborsKernel, 2);
            cl.executeKernel(neighborsKernel, numRealParticles);

            // Start a non-blocking download of the neighbor block count, and launch the
            // force kernel before waiting for the download to finish.

            cl_int* count = (cl_int*) cl.getPinnedBuffer();
            cl::Event event;
            cl.getQueue().enqueueReadBuffer(neighborBlockCount.getDeviceBuffer(), CL_FALSE, 0, neighborBlockCount.getSize()*neighborBlockCount.getElementSize(), count, NULL, &event);
            setPeriodicBoxArgs(cl, forceKernel, 20);
            cl.executeKernel(forceKernel, cl.getNonbondedUtilities().getNumForceThreadBlocks()*cl.getNonbondedUtilities().getForceThreadBlockSize());
            event.wait();
            if (*count <= maxNeighborBlocks)
                break;
            
            // There wasn't enough room for the neighbor list, so we need to recreate it.

            maxNeighborBlocks = (int) ceil((*count)*1.1);
            neighbors.resize(maxNeighborBlocks*32);
            neighborIndex.resize(maxNeighborBlocks);

            // Kernel arguments keep the value they were last given, so the new capacity
            // has to be rebound along with the reallocated buffers.  Otherwise both kernels
            // would keep seeing the old maxNeighborBlocks on the retry.

            neighborsKernel.setArg<cl_int>(1, maxNeighborBlocks);
            neighborsKernel.setArg<cl::Buffer>(10, neighbors.getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(11, neighborIndex.getDeviceBuffer());
            forceKernel.setArg<cl_int>(16, maxNeighborBlocks);
            forceKernel.setArg<cl::Buffer>(17, neighbors.getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(18, neighborIndex.getDeviceBuffer());
        }
    }
    cl.executeKernel(torqueKernel, numRealParticles);
    return 0.0;
}

// Re-reads the per-particle and exception parameters from a GayBerneForce and uploads them
// to the device.  Throws OpenMMException if the number of particles, the set of ignored
// particles, or the set of non-excluded exceptions differs from what this kernel was
// set up with.  The context argument is not consulted.
void OpenCLCalcGayBerneForceKernel::copyParametersToContext(ContextImpl& context, const GayBerneForce& force) {
    // Make sure the new parameters are acceptable.
    
    if (force.getNumParticles() != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Walk the force's exceptions in order, matching them against the stored exception
    // pairs.  exceptions[k] records which force exception corresponds to exceptionAtoms[k].

    vector<int> exceptions;
    for (int i = 0; i < force.getNumExceptions(); i++) {
        int particle1, particle2;
        double sigma, epsilon;
        force.getExceptionParameters(i, particle1, particle2, sigma, epsilon);
        if (exceptionAtoms.size() > exceptions.size() && make_pair(particle1, particle2) == exceptionAtoms[exceptions.size()])
            exceptions.push_back(i);
        else if (epsilon != 0.0)
            throw OpenMMException("updateParametersInContext: The set of non-excluded exceptions has changed");
    }

    // Every stored exception pair must have been matched.  Without this check the loop
    // that records the exceptions below would index past the end of the exceptions vector.

    if (exceptions.size() != exceptionAtoms.size())
        throw OpenMMException("updateParametersInContext: The set of non-excluded exceptions has changed");
    int numExceptions = exceptionAtoms.size();
    
    // Record the per-particle parameters.
    
    vector<mm_float4> sigParamsVector(cl.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
    vector<mm_float2> epsParamsVector(cl.getPaddedNumAtoms(), mm_float2(0, 0));
    vector<mm_float4> scaleVector(cl.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
    for (int i = 0; i < force.getNumParticles(); i++) {
        int xparticle, yparticle;
        double sigma, epsilon, sx, sy, sz, ex, ey, ez;
        force.getParticleParameters(i, sigma, epsilon, xparticle, yparticle, sx, sy, sz, ex, ey, ez);
        sigParamsVector[i] = mm_float4((float) (0.5*sigma), (float) (0.25*sx*sx), (float) (0.25*sy*sy), (float) (0.25*sz*sz));
        epsParamsVector[i] = mm_float2((float) sqrt(epsilon), (float) (0.125*(sx*sy + sz*sz)*sqrt(sx*sy)));
        scaleVector[i] = mm_float4((float) (1/sqrt(ex)), (float) (1/sqrt(ey)), (float) (1/sqrt(ez)), 0);
        if (epsilon != 0.0 && !isRealParticle[i])
            throw OpenMMException("updateParametersInContext: The set of ignored particles (ones with epsilon=0) has changed");
    }
    sigParams.upload(sigParamsVector);
    epsParams.upload(epsParamsVector);
    scale.upload(scaleVector);
    
    // Record the exceptions.
    
    if (numExceptions > 0) {
        vector<mm_float2> exceptionParamsVec(numExceptions);
        for (int i = 0; i < numExceptions; i++) {
            int atom1, atom2;
            double sigma, epsilon;
            force.getExceptionParameters(exceptions[i], atom1, atom2, sigma, epsilon);
            exceptionParamsVec[i] = mm_float2((float) sigma, (float) epsilon);
        }
        exceptionParams.upload(exceptionParamsVec);
    }

    // Mark the current molecule information as invalid and rebuild the sorted atom lists.

    cl.invalidateMolecules(info);
    sortAtoms();
}

void OpenCLCalcGayBerneForceKernel::sortAtoms() {
    // Sort the list of atoms by type to avoid thread divergence.  This is executed every time
    // the atoms are reordered.
    
    int nextIndex = 0;
    vector<cl_int> particles(cl.getPaddedNumAtoms(), 0);
    const vector<int>& order = cl.getAtomIndex();
6888
    vector<int> inverseOrder(order.size(), -1);
6889
6890
    for (int i = 0; i < cl.getNumAtoms(); i++) {
        int atom = order[i];
6891
6892
        if (isRealParticle[atom]) {
            inverseOrder[atom] = nextIndex;
6893
            particles[nextIndex++] = atom;
6894
        }
6895
    }
peastman's avatar
peastman committed
6896
    sortedParticles.upload(particles);
6897
    
6898
6899
6900
6901
6902
6903
6904
    // Update the list of exception particles.
    
    int numExceptions = exceptionAtoms.size();
    if (numExceptions > 0) {
        vector<mm_int4> exceptionParticlesVec(numExceptions);
        for (int i = 0; i < numExceptions; i++)
            exceptionParticlesVec[i] = mm_int4(exceptionAtoms[i].first, exceptionAtoms[i].second, inverseOrder[exceptionAtoms[i].first], inverseOrder[exceptionAtoms[i].second]);
peastman's avatar
peastman committed
6905
        exceptionParticles.upload(exceptionParticlesVec);
6906
6907
    }
    
6908
6909
6910
6911
    // Rebuild the list of exclusions.
    
    vector<vector<int> > excludedAtoms(numRealParticles);
    for (int i = 0; i < excludedPairs.size(); i++) {
6912
6913
        int first = inverseOrder[min(excludedPairs[i].first, excludedPairs[i].second)];
        int second = inverseOrder[max(excludedPairs[i].first, excludedPairs[i].second)];
6914
6915
6916
        excludedAtoms[first].push_back(second);
    }
    int index = 0;
peastman's avatar
peastman committed
6917
6918
    vector<int> exclusionVec(exclusions.getSize());
    vector<int> startIndexVec(exclusionStartIndex.getSize());
6919
6920
6921
6922
6923
6924
    for (int i = 0; i < numRealParticles; i++) {
        startIndexVec[i] = index;
        for (int j = 0; j < excludedAtoms[i].size(); j++)
            exclusionVec[index++] = excludedAtoms[i][j];
    }
    startIndexVec[numRealParticles] = index;
peastman's avatar
peastman committed
6925
6926
    exclusions.upload(exclusionVec);
    exclusionStartIndex.upload(startIndexVec);
6927
6928
}

6929
6930
6931
6932
6933
6934
6935
6936
6937
6938
6939
6940
6941
6942
6943
6944
6945
6946
6947
6948
// An OpenCLForceInfo that answers every query by forwarding it to another OpenCLForceInfo,
// which is held by reference.
class OpenCLCalcCustomCVForceKernel::ForceInfo : public OpenCLForceInfo {
private:
    OpenCLForceInfo& wrapped;
public:
    ForceInfo(OpenCLForceInfo& force) : OpenCLForceInfo(0), wrapped(force) {
    }

    // Particle queries.

    bool areParticlesIdentical(int particle1, int particle2) {
        return wrapped.areParticlesIdentical(particle1, particle2);
    }

    // Group queries.

    int getNumParticleGroups() {
        return wrapped.getNumParticleGroups();
    }
    void getParticlesInGroup(int index, std::vector<int>& particles) {
        wrapped.getParticlesInGroup(index, particles);
    }
    bool areGroupsIdentical(int group1, int group2) {
        return wrapped.areGroupsIdentical(group1, group2);
    }
};

6949
6950
6951
6952
6953
// A reorder listener that uploads the inverse of a context's atom order to a device array.
class OpenCLCalcCustomCVForceKernel::ReorderListener : public OpenCLContext::ReorderListener {
public:
    ReorderListener(OpenCLContext& cl, OpenCLArray& invAtomOrder) : cl(cl), invAtomOrder(invAtomOrder) {
    }

    // Builds the inverse mapping (atom index -> position in the current order) and
    // uploads it.  The array is sized to the padded number of atoms.
    void execute() {
        const vector<int>& atomOrder = cl.getAtomIndex();
        vector<cl_int> positionOfAtom(cl.getPaddedNumAtoms());
        const int numOrdered = atomOrder.size();
        for (int position = 0; position < numOrdered; position++) {
            const int atom = atomOrder[position];
            positionOfAtom[atom] = position;
        }
        invAtomOrder.upload(positionOfAtom);
    }
private:
    OpenCLContext& cl;
    OpenCLArray& invAtomOrder;
};

6965
void OpenCLCalcCustomCVForceKernel::initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext) {
6966
6967
6968
6969
    int numCVs = force.getNumCollectiveVariables();
    cl.addForce(new OpenCLForceInfo(1));
    for (int i = 0; i < force.getNumGlobalParameters(); i++)
        globalParameterNames.push_back(force.getGlobalParameterName(i));
6970
6971
    for (int i = 0; i < numCVs; i++)
        variableNames.push_back(force.getCollectiveVariableName(i));
6972
6973
6974
    for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++) {
        string name = force.getEnergyParameterDerivativeName(i);
        paramDerivNames.push_back(name);
6975
        cl.addEnergyParameterDerivative(name);
6976
    }
6977
6978
6979
6980
6981
6982
6983
6984
6985
6986
6987
6988
6989
6990
6991
6992
6993
6994
6995
6996
6997
6998
6999

    // Create custom functions for the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    for (int i = 0; i < (int) force.getNumTabulatedFunctions(); i++)
        functions[force.getTabulatedFunctionName(i)] = createReferenceTabulatedFunction(force.getTabulatedFunction(i));

    // Create the expressions.

    Lepton::ParsedExpression energyExpr = Lepton::Parser::parse(force.getEnergyFunction(), functions);
    energyExpression = energyExpr.createProgram();
    variableDerivExpressions.clear();
    for (auto& name : variableNames)
        variableDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());
    paramDerivExpressions.clear();
    for (auto& name : paramDerivNames)
        paramDerivExpressions.push_back(energyExpr.differentiate(name).optimize().createProgram());

    // Delete the custom functions.

    for (auto& function : functions)
        delete function.second;

7000
7001
7002
7003
7004
    // Copy parameter derivatives from the inner context.

    OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
    for (auto& param : cl2.getEnergyParamDerivNames())
        cl.addEnergyParameterDerivative(param);
7005
7006
7007
7008
    
    // Create arrays for storing information.
    
    int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
peastman's avatar
peastman committed
7009
    cvForces.resize(numCVs);
7010
    for (int i = 0; i < numCVs; i++)
peastman's avatar
peastman committed
7011
7012
7013
        cvForces[i].initialize(cl, cl.getNumAtoms(), 4*elementSize, "cvForce");
    invAtomOrder.initialize<cl_int>(cl, cl.getPaddedNumAtoms(), "invAtomOrder");
    innerInvAtomOrder.initialize<cl_int>(cl, cl.getPaddedNumAtoms(), "innerInvAtomOrder");
7014
7015
7016
7017
7018
7019
7020
7021
7022
7023
7024
7025
7026
7027
7028
    
    // Create the kernels.
    
    stringstream args, add;
    for (int i = 0; i < numCVs; i++) {
        args << ", __global real4* restrict force" << i << ", real dEdV" << i;
        add << "f += force" << i << "[i]*dEdV" << i << ";\n";
    }
    map<string, string> replacements;
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    replacements["ADD_FORCES"] = add.str();
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customCVForce, replacements));
    copyStateKernel = cl::Kernel(program, "copyState");
    copyForcesKernel = cl::Kernel(program, "copyForces");
    addForcesKernel = cl::Kernel(program, "addForces");
7029
7030
7031
7032
7033

    // This context needs to respect all forces in the inner context when reordering atoms.

    for (OpenCLForceInfo* info : cl2.getForceInfos())
        cl.addForce(new ForceInfo(*info));
7034
7035
7036
7037
7038
7039
7040
7041
7042
7043
7044
}

// Evaluates the CustomCVForce.  The state of the outer context is copied to the inner
// context, each collective variable is computed there (one force group per variable), and
// the energy expression is then evaluated on the resulting values.  Forces are added on
// the device by the addForces kernel, and energy parameter derivatives are accumulated
// into the context's workspace.  Returns the value of the energy expression.  The
// includeForces and includeEnergy arguments are not consulted.
double OpenCLCalcCustomCVForceKernel::execute(ContextImpl& context, ContextImpl& innerContext, bool includeForces, bool includeEnergy) {
    copyState(context, innerContext);
    int numCVs = variableNames.size();
    int numAtoms = cl.getNumAtoms();

    // Compute each collective variable in the inner context, keeping a copy of its forces
    // and its energy parameter derivatives.

    vector<double> cvValues;
    vector<map<string, double> > cvDerivs(numCVs);
    for (int i = 0; i < numCVs; i++) {
        cvValues.push_back(innerContext.calcForcesAndEnergy(true, true, 1<<i));
        copyForcesKernel.setArg<cl::Buffer>(0, cvForces[i].getDeviceBuffer());
        cl.executeKernel(copyForcesKernel, numAtoms);
        innerContext.getEnergyParameterDerivatives(cvDerivs[i]);
    }
    
    // Compute the energy and forces.
    
    map<string, double> variables;
    for (auto& name : globalParameterNames)
        variables[name] = context.getParameter(name);
    for (int i = 0; i < numCVs; i++)
        variables[variableNames[i]] = cvValues[i];
    double energy = energyExpression.evaluate(variables);

    // Each derivative is needed twice (for the forces and for the chain rule below), so
    // evaluate it once and keep the result.

    vector<double> dEdV(numCVs);
    for (int i = 0; i < numCVs; i++) {
        dEdV[i] = variableDerivExpressions[i].evaluate(variables);
        if (cl.getUseDoublePrecision())
            addForcesKernel.setArg<cl_double>(2*i+3, dEdV[i]);
        else
            addForcesKernel.setArg<cl_float>(2*i+3, dEdV[i]);
    }
    cl.executeKernel(addForcesKernel, numAtoms);
    
    // Compute the energy parameter derivatives.
    
    map<string, double>& energyParamDerivs = cl.getEnergyParamDerivWorkspace();
    for (int i = 0; i < paramDerivExpressions.size(); i++)
        energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate(variables);
    for (int i = 0; i < numCVs; i++) {
        for (auto& deriv : cvDerivs[i])
            energyParamDerivs[deriv.first] += dEdV[i]*deriv.second;
    }
    return energy;
}

// Copies the state of the outer context into the inner context: the copyState kernel is
// run on the device, and the periodic box vectors, the time, and the values of the inner
// context's parameters are copied across on the host.  The first call also registers the
// reorder listeners and binds the fixed kernel arguments.
void OpenCLCalcCustomCVForceKernel::copyState(ContextImpl& context, ContextImpl& innerContext) {
    const int numAtoms = cl.getNumAtoms();
    OpenCLContext& cl2 = *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // Register one listener per context and run each once so that both inverse atom
        // order arrays start out populated.

        ReorderListener* outerListener = new ReorderListener(cl, invAtomOrder);
        ReorderListener* innerListener = new ReorderListener(cl2, innerInvAtomOrder);
        cl.addReorderListener(outerListener);
        cl2.addReorderListener(innerListener);
        outerListener->execute();
        innerListener->execute();

        // Bind the arguments of the copyState kernel.  Arguments 1 and 5 are the position
        // correction buffers, which are only passed in mixed precision mode.

        const bool mixedPrecision = cl.getUseMixedPrecision();
        copyStateKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        if (mixedPrecision)
            copyStateKernel.setArg<cl::Buffer>(1, cl.getPosqCorrection().getDeviceBuffer());
        else
            copyStateKernel.setArg<void*>(1, NULL);
        copyStateKernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        copyStateKernel.setArg<cl::Buffer>(3, cl.getAtomIndexArray().getDeviceBuffer());
        copyStateKernel.setArg<cl::Buffer>(4, cl2.getPosq().getDeviceBuffer());
        if (mixedPrecision)
            copyStateKernel.setArg<cl::Buffer>(5, cl2.getPosqCorrection().getDeviceBuffer());
        else
            copyStateKernel.setArg<void*>(5, NULL);
        copyStateKernel.setArg<cl::Buffer>(6, cl2.getVelm().getDeviceBuffer());
        copyStateKernel.setArg<cl::Buffer>(7, innerInvAtomOrder.getDeviceBuffer());
        copyStateKernel.setArg<cl_int>(8, numAtoms);

        // Bind the fixed arguments of the copyForces kernel.  Argument 0 is not set here.

        copyForcesKernel.setArg<cl::Buffer>(1, invAtomOrder.getDeviceBuffer());
        copyForcesKernel.setArg<cl::Buffer>(2, cl2.getForce().getDeviceBuffer());
        copyForcesKernel.setArg<cl::Buffer>(3, cl2.getAtomIndexArray().getDeviceBuffer());
        copyForcesKernel.setArg<cl_int>(4, numAtoms);

        // Bind the fixed arguments of the addForces kernel.  The force array of collective
        // variable k goes in argument 2*k+2; the odd arguments in between are not set here.

        addForcesKernel.setArg<cl::Buffer>(0, cl.getForce().getDeviceBuffer());
        addForcesKernel.setArg<cl_int>(1, numAtoms);
        for (int cv = 0; cv < cvForces.size(); cv++)
            addForcesKernel.setArg<cl::Buffer>(2*cv+2, cvForces[cv].getDeviceBuffer());
    }
    cl.executeKernel(copyStateKernel, numAtoms);

    // Copy the box vectors and the time.

    Vec3 boxA, boxB, boxC;
    context.getPeriodicBoxVectors(boxA, boxB, boxC);
    innerContext.setPeriodicBoxVectors(boxA, boxB, boxC);
    innerContext.setTime(context.getTime());

    // Give every parameter of the inner context the value it has in the outer context.
    // The loop runs over a copy of the inner context's parameter map.

    map<string, double> innerParameters = innerContext.getParameters();
    for (auto& entry : innerParameters) {
        const string& paramName = entry.first;
        innerContext.setParameter(paramName, context.getParameter(paramName));
    }
}

7133
7134
7135
// Swaps the tabulated functions of a CustomCVForce into the stored energy, variable
// derivative and parameter derivative expressions.  The context argument is not consulted.
void OpenCLCalcCustomCVForceKernel::copyParametersToContext(ContextImpl& context, const CustomCVForce& force) {
    // Wrap each tabulated function of the force, keyed by its name.

    map<string, CustomFunction*> functions;
    const int numFunctions = (int) force.getNumTabulatedFunctions();
    for (int f = 0; f < numFunctions; f++) {
        CustomFunction* wrapper = createReferenceTabulatedFunction(force.getTabulatedFunction(f));
        functions[force.getTabulatedFunctionName(f)] = wrapper;
    }

    // Update the energy expression first, then the two lists of derivative expressions.

    replaceFunctionsInExpression(functions, energyExpression);
    for (auto& derivExpression : variableDerivExpressions)
        replaceFunctionsInExpression(functions, derivExpression);
    for (auto& derivExpression : paramDerivExpressions)
        replaceFunctionsInExpression(functions, derivExpression);

    // Free the wrappers created above.

    for (auto& entry : functions)
        delete entry.second;
}

peastman's avatar
peastman committed
7154
7155
7156
7157
7158
7159
7160
7161
7162
7163
7164
7165
7166
7167
7168
7169
7170
7171
7172
7173
7174
7175
7176
7177
7178
7179
7180
7181
// Force info for an RMSDForce.  It keeps a set of the particle indices returned by the
// force's getParticles(), and treats two particles as identical when both or neither
// are in that set.
class OpenCLCalcRMSDForceKernel::ForceInfo : public OpenCLForceInfo {
public:
    ForceInfo(const RMSDForce& force) : OpenCLForceInfo(0), force(force) {
        updateParticles();
    }

    // Rebuilds the set from the force's current particle list.
    void updateParticles() {
        particles.clear();
        const vector<int>& listed = force.getParticles();
        particles.insert(listed.begin(), listed.end());
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        const bool firstListed = (particles.count(particle1) != 0);
        const bool secondListed = (particles.count(particle2) != 0);
        return (firstListed == secondListed);
    }
private:
    const RMSDForce& force;
    set<int> particles;
};

void OpenCLCalcRMSDForceKernel::initialize(const System& system, const RMSDForce& force) {
    // Create data structures.
    
    bool useDouble = cl.getUseDoublePrecision();
    int elementSize = (useDouble ? sizeof(cl_double) : sizeof(cl_float));
    int numParticles = force.getParticles().size();
    if (numParticles == 0)
        numParticles = system.getNumParticles();
peastman's avatar
peastman committed
7182
7183
7184
    referencePos.initialize(cl, system.getNumParticles(), 4*elementSize, "referencePos");
    particles.initialize<cl_int>(cl, numParticles, "particles");
    buffer.initialize(cl, 13, elementSize, "buffer");
peastman's avatar
peastman committed
7185
7186
7187
7188
7189
7190
7191
7192
7193
7194
7195
7196
7197
7198
7199
7200
7201
7202
7203
7204
7205
7206
7207
7208
7209
7210
7211
7212
    recordParameters(force);
    info = new ForceInfo(force);
    cl.addForce(info);
    
    // Create the kernels.

    cl::Program program = cl.createProgram(OpenCLKernelSources::rmsd);
    kernel1 = cl::Kernel(program, "computeRMSDPart1");
    kernel2 = cl::Kernel(program, "computeRMSDForces");
}

void OpenCLCalcRMSDForceKernel::recordParameters(const RMSDForce& force) {
    // Record the parameters and center the reference positions.
    
    vector<int> particleVec = force.getParticles();
    if (particleVec.size() == 0)
        for (int i = 0; i < cl.getNumAtoms(); i++)
            particleVec.push_back(i);
    vector<Vec3> centeredPositions = force.getReferencePositions();
    Vec3 center;
    for (int i : particleVec)
        center += centeredPositions[i];
    center /= particleVec.size();
    for (Vec3& p : centeredPositions)
        p -= center;

    // Upload them to the device.

peastman's avatar
peastman committed
7213
    particles.upload(particleVec);
peastman's avatar
peastman committed
7214
7215
7216
7217
    vector<mm_double4> pos;
    for (Vec3 p : centeredPositions)
        pos.push_back(mm_double4(p[0], p[1], p[2], 0));
    referencePos.upload(pos, true, true);
peastman's avatar
peastman committed
7218
7219
7220
7221
7222
7223
7224
7225
7226
7227
7228
7229
7230
7231
7232
7233
7234
7235
7236
7237

    // Record the sum of the norms of the reference positions.

    sumNormRef = 0.0;
    for (int i : particleVec) {
        Vec3 p = centeredPositions[i];
        sumNormRef += p.dot(p);
    }
}

// Dispatches to the implementation matching the context's precision: double in double
// precision mode, float otherwise.  includeForces and includeEnergy are not consulted.
double OpenCLCalcRMSDForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const bool useDouble = cl.getUseDoublePrecision();
    return (useDouble ? executeImpl<double>(context) : executeImpl<float>(context));
}

// Computes the RMSD and applies the corresponding forces, using REAL as the host-side type
// of the values exchanged through the 13-element device buffer.  Returns the RMSD, or 0.0
// without running the force kernel when the mean squared deviation is below 1e-20.
template <class REAL>
double OpenCLCalcRMSDForceKernel::executeImpl(ContextImpl& context) {
    // Run the first kernel as a single work group of at most 256 threads, with one REAL
    // of local memory per thread.

    const int numParticles = particles.getSize();
    const int blockSize = min(256, (int) kernel1.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(cl.getDevice()));
    kernel1.setArg<cl_int>(0, numParticles);
    kernel1.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
    kernel1.setArg<cl::Buffer>(2, referencePos.getDeviceBuffer());
    kernel1.setArg<cl::Buffer>(3, particles.getDeviceBuffer());
    kernel1.setArg<cl::Buffer>(4, buffer.getDeviceBuffer());
    kernel1.setArg(5, blockSize*sizeof(REAL), NULL);
    cl.executeKernel(kernel1, blockSize, blockSize);

    // Download the results.  The first nine values are read as a 3x3 matrix stored row by
    // row; mRC below is the element in row R, column C.

    vector<REAL> b;
    buffer.download(b);
    const REAL m00 = b[0], m01 = b[1], m02 = b[2];
    const REAL m10 = b[3], m11 = b[4], m12 = b[5];
    const REAL m20 = b[6], m21 = b[7], m22 = b[8];

    // Build the symmetric 4x4 F matrix: the diagonal first, then each off-diagonal pair.

    Array2D<double> F(4, 4);
    F[0][0] =  m00 + m11 + m22;
    F[1][1] =  m00 - m11 - m22;
    F[2][2] = -m00 + m11 - m22;
    F[3][3] = -m00 - m11 + m22;
    F[0][1] = F[1][0] = m12 - m21;
    F[0][2] = F[2][0] = m20 - m02;
    F[0][3] = F[3][0] = m01 - m10;
    F[1][2] = F[2][1] = m01 + m10;
    F[1][3] = F[3][1] = m02 + m20;
    F[2][3] = F[3][2] = m12 + m21;

    // Find the eigenvalues and eigenvectors of F.  The code below uses the eigenvalue at
    // index 3 and the matching column of the eigenvector matrix.

    JAMA::Eigenvalue<double> eigen(F);
    Array1D<double> values;
    eigen.getRealEigenvalues(values);
    Array2D<double> vectors;
    eigen.getV(vectors);

    // Compute the RMSD.

    const double msd = (sumNormRef+b[9]-2*values[3])/numParticles;
    if (msd < 1e-20) {
        // The particles are perfectly aligned, so all the forces should be zero.
        // Numerical error can lead to NaNs, so just return 0 now.
        return 0.0;
    }
    const double rmsd = sqrt(msd);
    b[9] = rmsd;

    // Compute the rotation matrix from the quaternion (q0, q1, q2, q3) held in column 3
    // of the eigenvector matrix, and store it in the first nine buffer values.

    const double q0 = vectors[0][3], q1 = vectors[1][3], q2 = vectors[2][3], q3 = vectors[3][3];
    const double q00 = q0*q0, q01 = q0*q1, q02 = q0*q2, q03 = q0*q3;
    const double q11 = q1*q1, q12 = q1*q2, q13 = q1*q3;
    const double q22 = q2*q2, q23 = q2*q3;
    const double q33 = q3*q3;
    b[0] = q00+q11-q22-q33;
    b[1] = 2*(q12-q03);
    b[2] = 2*(q13+q02);
    b[3] = 2*(q12+q03);
    b[4] = q00-q11+q22-q33;
    b[5] = 2*(q23-q01);
    b[6] = 2*(q13-q02);
    b[7] = 2*(q23+q01);
    b[8] = q00-q11-q22+q33;

    // Upload it to the device and invoke the kernel to apply forces.

    buffer.upload(b);
    kernel2.setArg<cl_int>(0, numParticles);
    kernel2.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(2, referencePos.getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(3, particles.getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(4, buffer.getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(5, cl.getForceBuffers().getDeviceBuffer());
    cl.executeKernel(kernel2, numParticles);
    return rmsd;
}

// Re-reads the parameters of an RMSDForce.  Throws OpenMMException if the number of
// reference positions no longer matches the size of the device array; the particle list
// is allowed to change size.
void OpenCLCalcRMSDForceKernel::copyParametersToContext(ContextImpl& context, const RMSDForce& force) {
    if (referencePos.getSize() != force.getReferencePositions().size())
        throw OpenMMException("updateParametersInContext: The number of reference positions has changed");

    // An empty particle list means that every particle in the system is used.

    const int numListed = force.getParticles().size();
    const int numParticles = (numListed == 0 ? context.getSystem().getNumParticles() : numListed);
    if (numParticles != particles.getSize())
        particles.resize(numParticles);
    recordParameters(force);

    // Mark that the current reordering may be invalid.

    info->updateParticles();
    cl.invalidateMolecules(info);
}

7333
7334
7335
7336
// The destructor body is empty; no explicit cleanup is performed here.
OpenCLIntegrateVerletStepKernel::~OpenCLIntegrateVerletStepKernel() {
}

void OpenCLIntegrateVerletStepKernel::initialize(const System& system, const VerletIntegrator& integrator) {
7337
    cl.getPlatformData().initializeContexts(system);
7338
    cl::Program program = cl.createProgram(OpenCLKernelSources::verlet, "");
7339
7340
    kernel1 = cl::Kernel(program, "integrateVerletPart1");
    kernel2 = cl::Kernel(program, "integrateVerletPart2");
7341
7342
7343
}

/**
 * Take one Verlet step: part 1 kernel, constraints, part 2 kernel, virtual sites,
 * then advance the context's time and step count.
 *
 * @param context     the context being integrated (not used directly here)
 * @param integrator  supplies the step size and the constraint tolerance
 */
void OpenCLIntegrateVerletStepKernel::execute(ContextImpl& context, const VerletIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int atomCount = cl.getNumAtoms();
    const double stepSize = integrator.getStepSize();

    // The kernel arguments never change, so they are bound only on the first call.
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        const auto& stepSizeBuffer = integration.getStepSize().getDeviceBuffer();
        const auto& posqBuffer = cl.getPosq().getDeviceBuffer();
        const auto& velmBuffer = cl.getVelm().getDeviceBuffer();
        const auto& posDeltaBuffer = integration.getPosDelta().getDeviceBuffer();

        kernel1.setArg<cl_int>(0, atomCount);
        kernel1.setArg<cl::Buffer>(1, stepSizeBuffer);
        kernel1.setArg<cl::Buffer>(2, posqBuffer);
        setPosqCorrectionArg(cl, kernel1, 3);
        kernel1.setArg<cl::Buffer>(4, velmBuffer);
        kernel1.setArg<cl::Buffer>(5, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(6, posDeltaBuffer);

        kernel2.setArg<cl_int>(0, atomCount);
        kernel2.setArg<cl::Buffer>(1, stepSizeBuffer);
        kernel2.setArg<cl::Buffer>(2, posqBuffer);
        setPosqCorrectionArg(cl, kernel2, 3);
        kernel2.setArg<cl::Buffer>(4, velmBuffer);
        kernel2.setArg<cl::Buffer>(5, posDeltaBuffer);
    }
    integration.setNextStepSize(stepSize);

    // First integration kernel, then constraints on its output.

    cl.executeKernel(kernel1, atomCount);
    integration.applyConstraints(integrator.getConstraintTolerance());

    // Second integration kernel, then virtual sites.

    cl.executeKernel(kernel2, atomCount);
    integration.computeVirtualSites();

    // Advance the clock and the step counter.

    cl.setTime(cl.getTime()+stepSize);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();

    // Reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif
}

7391
7392
7393
7394
/**
 * Compute the kinetic energy through the shared integration utilities.
 *
 * Half of the integrator's step size is forwarded as the argument.
 * NOTE(review): presumably a time shift for the velocities - confirm against
 * OpenCLIntegrationUtilities::computeKineticEnergy.
 */
double OpenCLIntegrateVerletStepKernel::computeKineticEnergy(ContextImpl& context, const VerletIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

7395
7396
7397
void OpenCLIntegrateVelocityVerletStepKernel::initialize(const System& system, const NoseHooverIntegrator& integrator) {
    cl.getPlatformData().initializeContexts(system);
    map<string, string> defines;
7398
    defines["BOLTZ"] = cl.doubleToString(BOLTZ);
7399
7400
7401
7402
    cl::Program program = cl.createProgram(OpenCLKernelSources::velocityVerlet, defines, "");
    kernel1 = cl::Kernel(program, "integrateVelocityVerletPart1");
    kernel2 = cl::Kernel(program, "integrateVelocityVerletPart2");
    kernel3 = cl::Kernel(program, "integrateVelocityVerletPart3");
7403
7404
7405
    kernelHardWall = cl::Kernel(program, "integrateVelocityVerletHardWall");
    prevMaxPairDistance = (cl_float) -1.0;
    maxPairDistanceBuffer.initialize<cl_float>(cl, 1, "maxPairDistanceBuffer");
7406
7407
7408
7409
7410
7411
7412
7413
7414
7415
}

/**
 * Take one velocity Verlet step for a NoseHooverIntegrator.
 *
 * Sequence: (optionally) compute forces, part 1 kernel, position constraints, part 2 kernel,
 * hard wall kernel (only when pairs exist), virtual sites, force evaluation, part 3 kernel,
 * velocity constraints, then advance time and step count.
 *
 * @param context         used to (re)compute forces
 * @param integrator      supplies the step size, tolerances, thermostated atom/pair lists and
 *                        the maximum pair distance
 * @param forcesAreValid  in: whether forces are already current; out: always set to true
 */
void OpenCLIntegrateVelocityVerletStepKernel::execute(ContextImpl& context, const NoseHooverIntegrator& integrator, bool &forcesAreValid) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int paddedAtomCount = cl.getPaddedNumAtoms();
    const double stepSize = integrator.getStepSize();
    integration.setNextStepSize(stepSize);

    if (!forcesAreValid)
        context.calcForcesAndEnergy(true, false);

    const auto& thermostatedAtoms = integrator.getAllThermostatedIndividualParticles();
    const auto& thermostatedPairs = integrator.getAllThermostatedPairs();
    const int numAtoms = thermostatedAtoms.size();
    const int numPairs = thermostatedPairs.size();
    const int numParticles = numAtoms + 2*numPairs;
    const float maxPairDistance = integrator.getMaximumPairDistance();

    // Make sure atom and pair metadata is uploaded and has the correct dimensions.

    if (prevMaxPairDistance != maxPairDistance) {
        std::vector<float> hostDistance(1, maxPairDistance);
        maxPairDistanceBuffer.upload(hostDistance);
        prevMaxPairDistance = maxPairDistance;
    }
    if (numAtoms != 0) {
        // (Re)upload only when the buffer is missing or its length no longer matches.
        if (!atomListBuffer.isInitialized()) {
            atomListBuffer.initialize<cl_int>(cl, thermostatedAtoms.size(), "atomListBuffer");
            atomListBuffer.upload(thermostatedAtoms);
        }
        else if (atomListBuffer.getSize() != numAtoms) {
            atomListBuffer.resize(thermostatedAtoms.size());
            atomListBuffer.upload(thermostatedAtoms);
        }
    }
    if (numPairs != 0) {
        const bool pairsAllocated = pairListBuffer.isInitialized();
        if (!pairsAllocated || pairListBuffer.getSize() != numPairs) {
            if (pairsAllocated) {
                pairListBuffer.resize(thermostatedPairs.size());
                pairTemperatureBuffer.resize(thermostatedPairs.size());
            }
            else {
                pairListBuffer.initialize<mm_int2>(cl, thermostatedPairs.size(), "pairListBuffer");
                pairTemperatureBuffer.initialize<cl_float>(cl, thermostatedPairs.size(), "pairTemperatureBuffer");
            }

            // Split each tuple into an index pair (elements 0 and 1) and a per-pair value (element 2).
            std::vector<mm_int2> hostPairs;
            std::vector<float> hostPairValues;
            for (const auto& entry : thermostatedPairs) {
                hostPairs.push_back(mm_int2(std::get<0>(entry), std::get<1>(entry)));
                hostPairValues.push_back(std::get<2>(entry));
            }
            pairListBuffer.upload(hostPairs);
            pairTemperatureBuffer.upload(hostPairValues);
        }
    }

    // Parts 1 and 3 take exactly the same argument list, so it is bound by one helper.
    auto bindPart1And3Args = [&](cl::Kernel& kernel) {
        kernel.setArg<cl_int>(0, numAtoms);
        kernel.setArg<cl_int>(1, numPairs);
        kernel.setArg<cl_int>(2, paddedAtomCount);
        kernel.setArg<cl::Buffer>(3, integration.getStepSize().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(4, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel, 5);
        kernel.setArg<cl::Buffer>(6, cl.getVelm().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(7, cl.getForce().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(8, integration.getPosDelta().getDeviceBuffer());
        // A list that is empty has no device buffer to pass, so a null argument stands in for it.
        if (numAtoms > 0)
            kernel.setArg<cl::Buffer>(9, atomListBuffer.getDeviceBuffer());
        else
            kernel.setArg<void*>(9, NULL);
        if (numPairs > 0)
            kernel.setArg<cl::Buffer>(10, pairListBuffer.getDeviceBuffer());
        else
            kernel.setArg<void*>(10, NULL);
    };

    //// Call the first integration kernel.

    bindPart1And3Args(kernel1);
    cl.executeKernel(kernel1, std::max(numAtoms, numPairs));

    //// Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    //// Call the second integration kernel.

    kernel2.setArg<cl_int>(0, numParticles);
    kernel2.setArg<cl::Buffer>(1, integration.getStepSize().getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
    setPosqCorrectionArg(cl, kernel2, 3);
    kernel2.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(5, integration.getPosDelta().getDeviceBuffer());
    cl.executeKernel(kernel2, numParticles);

    //// Enforce hard wall constraint (only meaningful when there are pairs).

    if (numPairs > 0) {
        kernelHardWall.setArg<cl_int>(0, numPairs);
        kernelHardWall.setArg<cl::Buffer>(1, maxPairDistanceBuffer.getDeviceBuffer());
        kernelHardWall.setArg<cl::Buffer>(2, integration.getStepSize().getDeviceBuffer());
        kernelHardWall.setArg<cl::Buffer>(3, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernelHardWall, 4);
        kernelHardWall.setArg<cl::Buffer>(5, cl.getVelm().getDeviceBuffer());
        kernelHardWall.setArg<cl::Buffer>(6, pairListBuffer.getDeviceBuffer());
        kernelHardWall.setArg<cl::Buffer>(7, pairTemperatureBuffer.getDeviceBuffer());
        cl.executeKernel(kernelHardWall, numPairs);
    }

    integration.computeVirtualSites();

    //// Update forces.

    context.calcForcesAndEnergy(true, false);
    forcesAreValid = true;

    //// Call the third integration kernel.

    bindPart1And3Args(kernel3);
    cl.executeKernel(kernel3, std::max(numAtoms, numPairs));

    integration.applyVelocityConstraints(integrator.getConstraintTolerance());

    //// Update the time and step count.

    cl.setTime(cl.getTime()+stepSize);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
}

/**
 * Compute the kinetic energy through the shared integration utilities, passing 0 as the argument.
 */
double OpenCLIntegrateVelocityVerletStepKernel::computeKineticEnergy(ContextImpl& context, const NoseHooverIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    return integration.computeKineticEnergy(0);
}

7544
void OpenCLIntegrateLangevinStepKernel::initialize(const System& system, const LangevinIntegrator& integrator) {
7545
    cl.getPlatformData().initializeContexts(system);
7546
7547
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    map<string, string> defines;
7548
7549
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
7550
    cl::Program program = cl.createProgram(OpenCLKernelSources::langevin, defines, "");
7551
7552
    kernel1 = cl::Kernel(program, "integrateLangevinPart1");
    kernel2 = cl::Kernel(program, "integrateLangevinPart2");
peastman's avatar
peastman committed
7553
    params.initialize(cl, 3, cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(cl_double) : sizeof(cl_float), "langevinParams");
7554
7555
7556
7557
    prevStepSize = -1.0;
}

/**
 * Take one Langevin step: refresh the parameter array if temperature, friction or step size
 * changed, run the part 1 kernel, apply constraints, run the part 2 kernel, compute virtual
 * sites, then advance the context's time and step count.
 *
 * @param context     the context being integrated (not used directly here)
 * @param integrator  supplies temperature, friction, step size and constraint tolerance
 */
void OpenCLIntegrateLangevinStepKernel::execute(ContextImpl& context, const LangevinIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int atomCount = cl.getNumAtoms();

    // Buffer arguments are bound once; only the random number offset (arg 6) changes per step.
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        const auto& velmBuffer = cl.getVelm().getDeviceBuffer();
        const auto& posDeltaBuffer = integration.getPosDelta().getDeviceBuffer();
        const auto& stepSizeBuffer = integration.getStepSize().getDeviceBuffer();

        kernel1.setArg<cl::Buffer>(0, velmBuffer);
        kernel1.setArg<cl::Buffer>(1, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, posDeltaBuffer);
        kernel1.setArg<cl::Buffer>(3, params.getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(4, stepSizeBuffer);
        kernel1.setArg<cl::Buffer>(5, integration.getRandom().getDeviceBuffer());

        kernel2.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 1);
        kernel2.setArg<cl::Buffer>(2, posDeltaBuffer);
        kernel2.setArg<cl::Buffer>(3, velmBuffer);
        kernel2.setArg<cl::Buffer>(4, stepSizeBuffer);
    }

    const double temperature = integrator.getTemperature();
    const double friction = integrator.getFriction();
    const double stepSize = integrator.getStepSize();
    integration.setNextStepSize(stepSize);

    const bool paramsUpToDate = (temperature == prevTemp && friction == prevFriction && stepSize == prevStepSize);
    if (!paramsUpToDate) {
        // Calculate the integration parameters.

        const double kT = BOLTZ*temperature;
        const double vscale = exp(-stepSize*friction);
        double fscale;
        if (friction == 0)
            fscale = stepSize;  // avoids dividing by zero friction
        else
            fscale = (1-vscale)/friction;
        const double noisescale = sqrt(kT*(1-vscale*vscale));

        // Layout of the parameter array: [0] vscale, [1] fscale, [2] noisescale.
        vector<cl_double> hostParams(params.getSize());
        hostParams[0] = vscale;
        hostParams[1] = fscale;
        hostParams[2] = noisescale;
        // NOTE(review): the meaning of the two `true` flags is defined by OpenCLArray::upload - confirm there.
        params.upload(hostParams, true, true);

        prevTemp = temperature;
        prevFriction = friction;
        prevStepSize = stepSize;
    }

    // First integration kernel, with a fresh offset into the random number buffer.

    kernel1.setArg<cl_uint>(6, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel1, atomCount);

    // Constraints on the output of the first kernel.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Second integration kernel, then virtual sites.

    cl.executeKernel(kernel2, atomCount);
    integration.computeVirtualSites();

    // Advance the clock and the step counter.

    cl.setTime(cl.getTime()+stepSize);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();

    // Reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif
}
7621

7622
7623
7624
7625
/**
 * Compute the kinetic energy through the shared integration utilities.
 *
 * Half of the integrator's step size is forwarded as the argument.
 * NOTE(review): presumably a time shift for the velocities - confirm against
 * OpenCLIntegrationUtilities::computeKineticEnergy.
 */
double OpenCLIntegrateLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const LangevinIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

7626
7627
7628
7629
// Nothing to release by hand; the body is empty, so the members clean up after themselves.
OpenCLIntegrateBrownianStepKernel::~OpenCLIntegrateBrownianStepKernel() = default;

void OpenCLIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) {
7630
    cl.getPlatformData().initializeContexts(system);
7631
7632
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    map<string, string> defines;
7633
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
7634
    cl::Program program = cl.createProgram(OpenCLKernelSources::brownian, defines, "");
7635
7636
7637
7638
7639
7640
7641
7642
7643
7644
7645
7646
    kernel1 = cl::Kernel(program, "integrateBrownianPart1");
    kernel2 = cl::Kernel(program, "integrateBrownianPart2");
    prevStepSize = -1.0;
}

/**
 * Take one Brownian dynamics step: refresh the scalar kernel arguments if temperature,
 * friction or step size changed, run the part 1 kernel, apply constraints, run the part 2
 * kernel, compute virtual sites, then advance the context's time and step count.
 *
 * @param context     the context being integrated (not used directly here)
 * @param integrator  supplies temperature, friction, step size and constraint tolerance
 */
void OpenCLIntegrateBrownianStepKernel::execute(ContextImpl& context, const BrownianIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int atomCount = cl.getNumAtoms();

    // Buffer arguments are bound once; scalar arguments (kernel1 0/1/6, kernel2 0) are set further down.
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        const auto& velmBuffer = cl.getVelm().getDeviceBuffer();
        const auto& posDeltaBuffer = integration.getPosDelta().getDeviceBuffer();

        kernel1.setArg<cl::Buffer>(2, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(3, posDeltaBuffer);
        kernel1.setArg<cl::Buffer>(4, velmBuffer);
        kernel1.setArg<cl::Buffer>(5, integration.getRandom().getDeviceBuffer());

        kernel2.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 2);
        kernel2.setArg<cl::Buffer>(3, velmBuffer);
        kernel2.setArg<cl::Buffer>(4, posDeltaBuffer);
    }

    const double temperature = integrator.getTemperature();
    const double friction = integrator.getFriction();
    const double stepSize = integrator.getStepSize();

    const bool scalarsUpToDate = (temperature == prevTemp && friction == prevFriction && stepSize == prevStepSize);
    if (!scalarsUpToDate) {
        // tau is the inverse friction, with zero friction mapped to zero rather than infinity.
        const double tau = (friction == 0.0 ? 0.0 : 1.0/friction);
        const double tauTimesStep = tau*stepSize;
        const double noiseScale = sqrt(2.0f*BOLTZ*temperature*stepSize*tau);
        const double inverseStep = 1.0/stepSize;

        // The scalar arguments must be passed in the precision the kernels were built with.
        if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
            kernel1.setArg<cl_double>(0, tauTimesStep);
            kernel1.setArg<cl_double>(1, noiseScale);
            kernel2.setArg<cl_double>(0, inverseStep);
        }
        else {
            kernel1.setArg<cl_float>(0, (cl_float) tauTimesStep);
            kernel1.setArg<cl_float>(1, (cl_float) noiseScale);
            kernel2.setArg<cl_float>(0, (cl_float) inverseStep);
        }
        prevTemp = temperature;
        prevFriction = friction;
        prevStepSize = stepSize;
    }

    // First integration kernel, with a fresh offset into the random number buffer.

    kernel1.setArg<cl_uint>(6, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel1, atomCount);

    // Constraints on the output of the first kernel.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Second integration kernel, then virtual sites.

    cl.executeKernel(kernel2, atomCount);
    integration.computeVirtualSites();

    // Advance the clock and the step counter.

    cl.setTime(cl.getTime()+stepSize);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();

    // Reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif
}
7700

7701
7702
7703
7704
/**
 * Compute the kinetic energy through the shared integration utilities, passing 0 as the argument.
 */
double OpenCLIntegrateBrownianStepKernel::computeKineticEnergy(ContextImpl& context, const BrownianIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    return integration.computeKineticEnergy(0);
}

7705
7706
7707
7708
// Nothing to release by hand; the body is empty, so the members clean up after themselves.
OpenCLIntegrateVariableVerletStepKernel::~OpenCLIntegrateVariableVerletStepKernel() = default;

void OpenCLIntegrateVariableVerletStepKernel::initialize(const System& system, const VariableVerletIntegrator& integrator) {
7709
    cl.getPlatformData().initializeContexts(system);
7710
    cl::Program program = cl.createProgram(OpenCLKernelSources::verlet, "");
7711
7712
7713
    kernel1 = cl::Kernel(program, "integrateVerletPart1");
    kernel2 = cl::Kernel(program, "integrateVerletPart2");
    selectSizeKernel = cl::Kernel(program, "selectVerletStepSize");
7714
    blockSize = min(min(256, system.getNumParticles()), (int) selectSizeKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(cl.getDevice()));
7715
7716
}

7717
/**
 * Take one variable-size Verlet step without stepping past maxTime.
 *
 * The step size selection kernel runs first with an upper bound on the step size, followed by
 * the part 1 kernel, constraints, the part 2 kernel and virtual sites.  The step size actually
 * used is then read back and the context's time and step count are advanced.
 *
 * @param context     the context being integrated (not used directly here)
 * @param integrator  supplies the error tolerance, constraint tolerance and optional maximum step size
 * @param maxTime     the simulation time that must not be overshot
 * @return the step size reported by getLastStepSize()
 */
double OpenCLIntegrateVariableVerletStepKernel::execute(ContextImpl& context, const VariableVerletIntegrator& integrator, double maxTime) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    // Buffer arguments are bound once; the scalar bounds (selectSizeKernel args 1 and 2) are set every step.
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        kernel1.setArg<cl_int>(0, numAtoms);
        kernel1.setArg<cl::Buffer>(1, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel1, 3);
        kernel1.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(6, integration.getPosDelta().getDeviceBuffer());
        kernel2.setArg<cl_int>(0, numAtoms);
        kernel2.setArg<cl::Buffer>(1, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 3);
        kernel2.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(5, integration.getPosDelta().getDeviceBuffer());
        selectSizeKernel.setArg<cl_int>(0, numAtoms);
        selectSizeKernel.setArg<cl::Buffer>(3, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(5, cl.getForce().getDeviceBuffer());
        // Argument 6 is a size-only (NULL data) argument: one element per work item of the block.
        int elementSize = (useDouble ? sizeof(cl_double) : sizeof(cl_float));
        selectSizeKernel.setArg(6, blockSize*elementSize, NULL);
    }

    // Select the step size to use.  The bound is the time left until maxTime, reduced further by
    // the integrator's own maximum step size when one is set (> 0).

    double timeRemaining = maxTime-cl.getTime();
    double maxStepSize = timeRemaining;
    double integratorLimit = integrator.getMaximumStepSize();
    if (integratorLimit > 0)
        maxStepSize = min(integratorLimit, maxStepSize);
    // Remember which limit is binding: only a step bounded by the time remaining ends exactly at maxTime.
    bool boundedByMaxTime = (maxStepSize == timeRemaining);
    float maxStepSizeFloat = (float) maxStepSize;
    if (useDouble) {
        selectSizeKernel.setArg<cl_double>(1, maxStepSize);
        selectSizeKernel.setArg<cl_double>(2, integrator.getErrorTolerance());
    }
    else {
        selectSizeKernel.setArg<cl_float>(1, maxStepSizeFloat);
        selectSizeKernel.setArg<cl_float>(2, (cl_float) integrator.getErrorTolerance());
    }
    cl.executeKernel(selectSizeKernel, blockSize, blockSize);

    // Call the first integration kernel.

    cl.executeKernel(kernel1, numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif

    // Update the time and step count.

    double dt = cl.getIntegrationUtilities().getLastStepSize();
    double time = cl.getTime()+dt;
    // Compare against the bound in the precision that was sent to the device.
    bool usedWholeBound = (useDouble ? dt == maxStepSize : dt == maxStepSizeFloat);
    // Snap to maxTime only when the step was bounded by the time remaining.  A step that merely
    // hit the integrator's maximum step size has advanced by dt, not all the way to maxTime.
    if (usedWholeBound && boundedByMaxTime)
        time = maxTime; // Avoid round-off error
    cl.setTime(time);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    return dt;
}

7797
7798
7799
7800
/**
 * Compute the kinetic energy through the shared integration utilities.
 *
 * Half of the integrator's step size is forwarded as the argument.
 * NOTE(review): presumably a time shift for the velocities - confirm against
 * OpenCLIntegrationUtilities::computeKineticEnergy.
 */
double OpenCLIntegrateVariableVerletStepKernel::computeKineticEnergy(ContextImpl& context, const VariableVerletIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

7801
void OpenCLIntegrateVariableLangevinStepKernel::initialize(const System& system, const VariableLangevinIntegrator& integrator) {
7802
    cl.getPlatformData().initializeContexts(system);
7803
7804
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    map<string, string> defines;
7805
7806
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
7807
    cl::Program program = cl.createProgram(OpenCLKernelSources::langevin, defines, "");
7808
7809
7810
    kernel1 = cl::Kernel(program, "integrateLangevinPart1");
    kernel2 = cl::Kernel(program, "integrateLangevinPart2");
    selectSizeKernel = cl::Kernel(program, "selectLangevinStepSize");
peastman's avatar
peastman committed
7811
    params.initialize(cl, 3, cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(cl_double) : sizeof(cl_float), "langevinParams");
Peter Eastman's avatar
Peter Eastman committed
7812
    blockSize = min(256, system.getNumParticles());
peastman's avatar
peastman committed
7813
    blockSize = max(blockSize, params.getSize());
7814
    blockSize = min(blockSize, (int) selectSizeKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(cl.getDevice()));
7815
7816
}

7817
/**
 * Take one variable-size Langevin step without stepping past maxTime.
 *
 * The step size selection kernel runs first with an upper bound on the step size, followed by
 * the part 1 kernel, constraints, the part 2 kernel and virtual sites.  The step size actually
 * used is then read back and the context's time and step count are advanced.
 *
 * @param context     the context being integrated (not used directly here)
 * @param integrator  supplies the error tolerance, friction, temperature, constraint tolerance
 *                    and optional maximum step size
 * @param maxTime     the simulation time that must not be overshot
 * @return the step size reported by getLastStepSize()
 */
double OpenCLIntegrateVariableLangevinStepKernel::execute(ContextImpl& context, const VariableLangevinIntegrator& integrator, double maxTime) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    // Buffer arguments are bound once; the scalars (selectSizeKernel args 0-3, kernel1 arg 6) are set every step.
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        kernel1.setArg<cl::Buffer>(0, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(1, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, integration.getPosDelta().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(3, params.getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, integration.getRandom().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 1);
        kernel2.setArg<cl::Buffer>(2, integration.getPosDelta().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(3, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(5, cl.getVelm().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(6, cl.getForce().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(7, params.getDeviceBuffer());
        // Arguments 8 and 9 are size-only (NULL data) arguments: one sized for the parameters,
        // one with an element per work item of the block.
        int elementSize = (useDouble ? sizeof(cl_double) : sizeof(cl_float));
        selectSizeKernel.setArg(8, params.getSize()*elementSize, NULL);
        selectSizeKernel.setArg(9, blockSize*elementSize, NULL);
    }

    // Select the step size to use.  The bound is the time left until maxTime, reduced further by
    // the integrator's own maximum step size when one is set (> 0).

    double timeRemaining = maxTime-cl.getTime();
    double maxStepSize = timeRemaining;
    double integratorLimit = integrator.getMaximumStepSize();
    if (integratorLimit > 0)
        maxStepSize = min(integratorLimit, maxStepSize);
    // Remember which limit is binding: only a step bounded by the time remaining ends exactly at maxTime.
    bool boundedByMaxTime = (maxStepSize == timeRemaining);
    float maxStepSizeFloat = (float) maxStepSize;
    if (useDouble) {
        selectSizeKernel.setArg<cl_double>(0, maxStepSize);
        selectSizeKernel.setArg<cl_double>(1, integrator.getErrorTolerance());
        selectSizeKernel.setArg<cl_double>(2, integrator.getFriction());
        selectSizeKernel.setArg<cl_double>(3, BOLTZ*integrator.getTemperature());
    }
    else {
        selectSizeKernel.setArg<cl_float>(0, maxStepSizeFloat);
        selectSizeKernel.setArg<cl_float>(1, (cl_float) integrator.getErrorTolerance());
        selectSizeKernel.setArg<cl_float>(2, (cl_float) integrator.getFriction());
        selectSizeKernel.setArg<cl_float>(3, (cl_float) (BOLTZ*integrator.getTemperature()));
    }
    cl.executeKernel(selectSizeKernel, blockSize, blockSize);

    // Call the first integration kernel.

    kernel1.setArg<cl_uint>(6, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel1, numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif

    // Update the time and step count.

    double dt = cl.getIntegrationUtilities().getLastStepSize();
    double time = cl.getTime()+dt;
    // Compare against the bound in the precision that was sent to the device.
    bool usedWholeBound = (useDouble ? dt == maxStepSize : dt == maxStepSizeFloat);
    // Snap to maxTime only when the step was bounded by the time remaining.  A step that merely
    // hit the integrator's maximum step size has advanced by dt, not all the way to maxTime.
    if (usedWholeBound && boundedByMaxTime)
        time = maxTime; // Avoid round-off error
    cl.setTime(time);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    return dt;
}

7901
7902
7903
7904
/**
 * Compute the kinetic energy through the shared integration utilities.
 *
 * Half of the integrator's step size is forwarded as the argument.
 * NOTE(review): presumably a time shift for the velocities - confirm against
 * OpenCLIntegrationUtilities::computeKineticEnergy.
 */
double OpenCLIntegrateVariableLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const VariableLangevinIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

7905
7906
/**
 * Keeps the per-DOF variables of a custom integrator consistent with the context's atom order.
 * Each time execute() runs, every per-DOF array is permuted from the previously recorded order
 * to the current one, on both the host copy and the device copy.
 */
class OpenCLIntegrateCustomStepKernel::ReorderListener : public OpenCLContext::ReorderListener {
public:
    /**
     * Store references to the integrator's per-DOF storage and record the current atom order.
     */
    ReorderListener(OpenCLContext& cl, vector<OpenCLArray>& perDofValues, vector<vector<mm_float4> >& localPerDofValuesFloat, vector<vector<mm_double4> >& localPerDofValuesDouble, vector<bool>& deviceValuesAreCurrent) :
            cl(cl), perDofValues(perDofValues), localPerDofValuesFloat(localPerDofValuesFloat), localPerDofValuesDouble(localPerDofValuesDouble), deviceValuesAreCurrent(deviceValuesAreCurrent) {
        const vector<int>& initialOrder = cl.getAtomIndex();
        lastAtomOrder.assign(initialOrder.begin(), initialOrder.begin()+cl.getNumAtoms());
    }
    /**
     * Reorder the per-DOF variables to reflect the new atom order.  Does nothing, and does
     * not update the recorded order, when there are no per-DOF variables.
     */
    void execute() {
        if (perDofValues.empty())
            return;
        const int atomCount = cl.getNumAtoms();
        const vector<int>& newOrder = cl.getAtomIndex();
        // Double-width host copies are used in double or mixed precision mode, float ones otherwise.
        const bool wideValues = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
        const int variableCount = (int) perDofValues.size();
        for (int v = 0; v < variableCount; v++) {
            if (wideValues)
                permute(perDofValues[v], localPerDofValuesDouble[v], deviceValuesAreCurrent[v], newOrder, atomCount);
            else
                permute(perDofValues[v], localPerDofValuesFloat[v], deviceValuesAreCurrent[v], newOrder, atomCount);
            deviceValuesAreCurrent[v] = true;
        }
        for (int slot = 0; slot < atomCount; slot++)
            lastAtomOrder[slot] = newOrder[slot];
    }
private:
    /**
     * Permute one per-DOF variable from lastAtomOrder to newOrder and upload the result.
     * The host copy is refreshed from the device first when the device holds the current values.
     */
    template <class VEC4>
    void permute(OpenCLArray& deviceValues, vector<VEC4>& hostValues, bool deviceIsCurrent, const vector<int>& newOrder, int atomCount) {
        if (deviceIsCurrent)
            deviceValues.download(hostValues);
        // Scatter by the old order, then gather by the new order.
        vector<VEC4> scattered(atomCount);
        for (int slot = 0; slot < atomCount; slot++)
            scattered[lastAtomOrder[slot]] = hostValues[slot];
        for (int slot = 0; slot < atomCount; slot++)
            hostValues[slot] = scattered[newOrder[slot]];
        deviceValues.upload(hostValues);
    }
    OpenCLContext& cl;
    vector<OpenCLArray>& perDofValues;
    vector<vector<mm_float4> >& localPerDofValuesFloat;
    vector<vector<mm_double4> >& localPerDofValuesDouble;
    vector<bool>& deviceValuesAreCurrent;
    vector<int> lastAtomOrder;  // atom order as of the previous execute() (or construction)
};

7956
7957
7958
7959
7960
7961
7962
7963
7964
7965
7966
7967
7968
7969
7970
7971
7972
7973
7974
7975
7976
/**
 * A zero-argument CustomFunction whose value is the entry of a shared map of energy parameter
 * derivatives stored under one fixed parameter name.
 */
class OpenCLIntegrateCustomStepKernel::DerivFunction : public CustomFunction {
public:
    /**
     * @param energyParamDerivs  map that is read (by reference) on every evaluation
     * @param param              key to look up in that map
     */
    DerivFunction(map<string, double>& energyParamDerivs, const string& param)
            : energyParamDerivs(energyParamDerivs), param(param) {
    }

    // The function takes no arguments.
    int getNumArguments() const {
        return 0;
    }

    // Current value stored under `param`; map::operator[] inserts a zero entry if none exists yet.
    double evaluate(const double* arguments) const {
        return energyParamDerivs[param];
    }

    // Every derivative is reported as zero.
    double evaluateDerivative(const double* arguments, const int* derivOrder) const {
        return 0;
    }

    // The copy refers to the same map, so it sees the same values.
    CustomFunction* clone() const {
        return new DerivFunction(energyParamDerivs, param);
    }

private:
    map<string, double>& energyParamDerivs;
    string param;
};

7977
7978
7979
7980
void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) {
    cl.getPlatformData().initializeContexts(system);
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    numGlobalVariables = integrator.getNumGlobalVariables();
7981
    int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
7982
    sumBuffer.initialize(cl, system.getNumParticles(), elementSize, "sumBuffer");
peastman's avatar
peastman committed
7983
    summedValue.initialize(cl, 1, elementSize, "summedValue");
7984
7985
7986
7987
7988
7989
7990
7991
    perDofValues.resize(integrator.getNumPerDofVariables());
    localPerDofValuesFloat.resize(perDofValues.size());
    localPerDofValuesDouble.resize(perDofValues.size());
    for (int i = 0; i < perDofValues.size(); i++)
        perDofValues[i].initialize(cl, system.getNumParticles(), 4*elementSize, "perDofVariables");
    localValuesAreCurrent.resize(integrator.getNumPerDofVariables(), false);
    deviceValuesAreCurrent.resize(integrator.getNumPerDofVariables(), false);
    cl.addReorderListener(new ReorderListener(cl, perDofValues, localPerDofValuesFloat, localPerDofValuesDouble, deviceValuesAreCurrent));
7992
7993
7994
    SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed());
}

7995
/**
 * Generate the OpenCL source fragment that evaluates one per-DOF expression and
 * stores the result.
 *
 * @param variable       where to store the result: "x" (position), "v" (velocity), the name of a
 *                       per-DOF variable, or "" to write the sum of the three components to sum[index]
 * @param expr           the expression to evaluate
 * @param integrator     the integrator that defines the global and per-DOF variable names
 * @param forceName      the name by which the expression refers to the force
 * @param energyName     the name by which the expression refers to the energy, or "" if it has none
 * @param functions      the tabulated functions that may appear in the expression
 * @param functionNames  (function name, device array name) pairs for the tabulated functions
 * @return the generated source code
 */
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, CustomIntegrator& integrator,
        const string& forceName, const string& energyName, vector<const TabulatedFunction*>& functions, vector<pair<string, string> >& functionNames) {
    // Intermediate results are computed in double precision whenever the device supports it.
    string tempType = (cl.getSupportsDoublePrecision() ? "double3" : "float3");
    string convert = (cl.getSupportsDoublePrecision() ? "convert_double3" : "");
    map<string, Lepton::ParsedExpression> expressions;
    expressions[tempType+" tempResult = "] = expr;

    // Map each name that may appear in the expression to the code that produces its value.
    map<string, string> variables;
    variables["x"] = convert+"(position.xyz)";
    variables["v"] = convert+"(velocity.xyz)";
    variables[forceName] = convert+"(f.xyz)";
    variables["gaussian"] = convert+"(gaussian.xyz)";
    variables["uniform"] = convert+"(uniform.xyz)";
    variables["m"] = "mass";
    variables["dt"] = "stepSize";
    if (energyName != "")
        variables[energyName] = "energy";
    for (int i = 0; i < integrator.getNumGlobalVariables(); i++)
        variables[integrator.getGlobalVariableName(i)] = "globals["+cl.intToString(globalVariableIndex[i])+"]";
    for (int i = 0; i < integrator.getNumPerDofVariables(); i++)
        variables[integrator.getPerDofVariableName(i)] = convert+"(perDof"+cl.intToString(i)+")";
    for (int i = 0; i < (int) parameterNames.size(); i++)
        variables[parameterNames[i]] = "globals["+cl.intToString(parameterVariableIndex[i])+"]";

    // Calls to deriv() are recorded first, followed by the plain variables.
    vector<pair<ExpressionTreeNode, string> > variableNodes;
    findExpressionsForDerivs(expr.getRootNode(), variableNodes);
    for (auto& var : variables)
        variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(var.first)), var.second));
    string result = cl.getExpressionUtilities().createExpressions(expressions, variableNodes, functions, functionNames, "temp", tempType);

    // Append the code that copies tempResult into its destination.
    if (variable == "x")
        result += "position.x = tempResult.x; position.y = tempResult.y; position.z = tempResult.z;\n";
    else if (variable == "v")
        result += "velocity.x = tempResult.x; velocity.y = tempResult.y; velocity.z = tempResult.z;\n";
    else if (variable == "")
        result += "sum[index] = tempResult.x+tempResult.y+tempResult.z;\n";
    else {
        for (int i = 0; i < integrator.getNumPerDofVariables(); i++)
            if (variable == integrator.getPerDofVariableName(i)) {
                string varName = "perDof"+cl.intToString(i);
                result += varName+".x = tempResult.x; "+varName+".y = tempResult.y; "+varName+".z = tempResult.z;\n";
            }
    }
    return result;
}

8038
void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
8039
8040
8041
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    int numSteps = integrator.getNumComputations();
8042
    bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
8043
    string tempType = (cl.getSupportsDoublePrecision() ? "double3" : "float3");
8044
    string perDofType = (useDouble ? "double4" : "float4");
8045
8046
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
8047
8048
8049
8050
        
        // Initialize various data structures.
        
        const map<string, double>& params = context.getParameters();
peastman's avatar
peastman committed
8051
8052
        for (auto& param : params)
            parameterNames.push_back(param.first);
8053
        kernels.resize(integrator.getNumComputations());
8054
8055
        requiredGaussian.resize(integrator.getNumComputations(), 0);
        requiredUniform.resize(integrator.getNumComputations(), 0);
8056
8057
8058
8059
        needsGlobals.resize(numSteps, false);
        globalExpressions.resize(numSteps);
        stepType.resize(numSteps);
        stepTarget.resize(numSteps);
8060
        merged.resize(numSteps, false);
8061
        modifiesParameters = false;
8062
8063
8064
        sumWorkGroupSize = cl.getDevice().getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
        if (sumWorkGroupSize > 512)
            sumWorkGroupSize = 512;
8065
        map<string, string> defines;
8066
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
8067
        defines["WORK_GROUP_SIZE"] = cl.intToString(sumWorkGroupSize);
8068
8069
8070
8071
8072
8073
8074

        // Record the tabulated functions.

        map<string, Lepton::CustomFunction*> functions;
        vector<pair<string, string> > functionNames;
        vector<const TabulatedFunction*> functionList;
        vector<string> tableTypes;
peastman's avatar
peastman committed
8075
        tabulatedFunctions.resize(integrator.getNumTabulatedFunctions());
8076
8077
8078
8079
8080
8081
8082
8083
        for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++) {
            functionList.push_back(&integrator.getTabulatedFunction(i));
            string name = integrator.getTabulatedFunctionName(i);
            string arrayName = "table"+cl.intToString(i);
            functionNames.push_back(make_pair(name, arrayName));
            functions[name] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i));
            int width;
            vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(integrator.getTabulatedFunction(i), width);
peastman's avatar
peastman committed
8084
8085
            tabulatedFunctions[i].initialize<float>(cl, f.size(), "TabulatedFunction");
            tabulatedFunctions[i].upload(f);
8086
8087
8088
8089
8090
8091
            if (width == 1)
                tableTypes.push_back("float");
            else
                tableTypes.push_back("float"+cl.intToString(width));
        }

8092
8093
8094
8095
8096
        // Record information about all the computation steps.

        vector<string> variable(numSteps);
        vector<int> forceGroup;
        vector<vector<Lepton::ParsedExpression> > expression;
8097
        CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions);
8098
8099
8100
        for (int step = 0; step < numSteps; step++) {
            string expr;
            integrator.getComputationStep(step, stepType[step], variable[step], expr);
8101
            if (stepType[step] == CustomIntegrator::WhileBlockStart)
8102
                blockEnd[blockEnd[step]] = step; // Record where to branch back to.
8103
            if (stepType[step] == CustomIntegrator::ComputeGlobal || stepType[step] == CustomIntegrator::IfBlockStart || stepType[step] == CustomIntegrator::WhileBlockStart)
peastman's avatar
peastman committed
8104
8105
                for (auto& expr : expression[step])
                    globalExpressions[step].push_back(ParsedExpression(replaceDerivFunctions(expr.getRootNode(), context)).createCompiledExpression());
8106
8107
        }
        for (int step = 0; step < numSteps; step++) {
peastman's avatar
peastman committed
8108
8109
            for (auto& expr : globalExpressions[step])
                expressionSet.registerExpression(expr);
8110
8111
        }
        
8112
        // Record the indices for variables in the CompiledExpressionSet.
8113
        
8114
8115
8116
8117
8118
        gaussianVariableIndex = expressionSet.getVariableIndex("gaussian");
        uniformVariableIndex = expressionSet.getVariableIndex("uniform");
        dtVariableIndex = expressionSet.getVariableIndex("dt");
        for (int i = 0; i < integrator.getNumGlobalVariables(); i++)
            globalVariableIndex.push_back(expressionSet.getVariableIndex(integrator.getGlobalVariableName(i)));
peastman's avatar
peastman committed
8119
8120
        for (auto& name : parameterNames)
            parameterVariableIndex.push_back(expressionSet.getVariableIndex(name));
8121
8122
8123
8124

        // Record the variable names and flags for the force and energy in each step.

        forceGroupFlags.resize(numSteps, -1);
8125
        vector<string> forceGroupName;
8126
        vector<string> energyGroupName;
8127
        for (int i = 0; i < 32; i++) {
8128
8129
8130
8131
8132
8133
            stringstream fname;
            fname << "f" << i;
            forceGroupName.push_back(fname.str());
            stringstream ename;
            ename << "energy" << i;
            energyGroupName.push_back(ename.str());
8134
8135
        }
        vector<string> forceName(numSteps, "f");
8136
        vector<string> energyName(numSteps, "energy");
8137
        stepEnergyVariableIndex.resize(numSteps, expressionSet.getVariableIndex("energy"));
8138
        for (int step = 0; step < numSteps; step++) {
8139
8140
8141
8142
8143
8144
8145
8146
8147
8148
            if (needsForces[step] && forceGroup[step] > -1)
                forceName[step] = forceGroupName[forceGroup[step]];
            if (needsEnergy[step] && forceGroup[step] > -1) {
                energyName[step] = energyGroupName[forceGroup[step]];
                stepEnergyVariableIndex[step] = expressionSet.getVariableIndex(energyName[step]);
            }
            if (forceGroup[step] > -1)
                forceGroupFlags[step] = 1<<forceGroup[step];
            if (forceGroupFlags[step] == -2 && step > 0)
                forceGroupFlags[step] = forceGroupFlags[step-1];
peastman's avatar
peastman committed
8149
8150
8151
8152
            if (forceGroupFlags[step] != -2 && savedForces.find(forceGroupFlags[step]) == savedForces.end()) {
                savedForces[forceGroupFlags[step]] = OpenCLArray();
                savedForces[forceGroupFlags[step]].initialize(cl, cl.getForce().getSize(), cl.getForce().getElementSize(), "savedForces");
            }
8153
8154
8155
8156
        }
        
        // Allocate space for storing global values, both on the host and the device.
        
peastman's avatar
peastman committed
8157
        localGlobalValues.resize(expressionSet.getNumVariables());
8158
        int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
peastman's avatar
peastman committed
8159
        globalValues.initialize(cl, expressionSet.getNumVariables(), elementSize, "globalValues");
8160
        for (int i = 0; i < integrator.getNumGlobalVariables(); i++) {
peastman's avatar
peastman committed
8161
            localGlobalValues[globalVariableIndex[i]] = initialGlobalVariables[i];
8162
8163
8164
8165
            expressionSet.setVariable(globalVariableIndex[i], initialGlobalVariables[i]);
        }
        for (int i = 0; i < (int) parameterVariableIndex.size(); i++) {
            double value = context.getParameter(parameterNames[i]);
peastman's avatar
peastman committed
8166
            localGlobalValues[parameterVariableIndex[i]] = value;
8167
8168
            expressionSet.setVariable(parameterVariableIndex[i], value);
        }
8169
        int numContextParams = context.getParameters().size();
peastman's avatar
peastman committed
8170
        localPerDofEnergyParamDerivs.resize(numContextParams);
peastman's avatar
peastman committed
8171
        perDofEnergyParamDerivs.initialize(cl, max(1, numContextParams), elementSize, "perDofEnergyParamDerivs");
8172
8173
8174
8175
8176
8177
8178
8179
8180
8181
        
        // Record information about the targets of steps that will be stored in global variables.
        
        for (int step = 0; step < numSteps; step++) {
            if (stepType[step] == CustomIntegrator::ComputeGlobal || stepType[step] == CustomIntegrator::ComputeSum) {
                if (variable[step] == "dt")
                    stepTarget[step].type = DT;
                for (int i = 0; i < integrator.getNumGlobalVariables(); i++)
                    if (variable[step] == integrator.getGlobalVariableName(i))
                        stepTarget[step].type = VARIABLE;
peastman's avatar
peastman committed
8182
8183
                for (auto& name : parameterNames)
                    if (variable[step] == name) {
8184
8185
                        stepTarget[step].type = PARAMETER;
                        modifiesParameters = true;
8186
                    }
8187
8188
8189
8190
8191
8192
8193
8194
8195
8196
8197
                stepTarget[step].variableIndex = expressionSet.getVariableIndex(variable[step]);
            }
        }

        // Identify which per-DOF steps are going to require global variables or context parameters.

        for (int step = 0; step < numSteps; step++) {
            if (stepType[step] == CustomIntegrator::ComputePerDof || stepType[step] == CustomIntegrator::ComputeSum) {
                for (int i = 0; i < integrator.getNumGlobalVariables(); i++)
                    if (usesVariable(expression[step][0], integrator.getGlobalVariableName(i)))
                        needsGlobals[step] = true;
peastman's avatar
peastman committed
8198
8199
                for (auto& name : parameterNames)
                    if (usesVariable(expression[step][0], name))
8200
                        needsGlobals[step] = true;
8201
            }
8202
8203
8204
8205
        }
        
        // Determine how each step will represent the position (as just a value, or a value plus a delta).
        
peastman's avatar
peastman committed
8206
        hasAnyConstraints = (context.getSystem().getNumConstraints() > 0);
8207
8208
        vector<bool> storePosAsDelta(numSteps, false);
        vector<bool> loadPosAsDelta(numSteps, false);
peastman's avatar
peastman committed
8209
8210
8211
8212
8213
        if (hasAnyConstraints) {
            bool beforeConstrain = false;
            for (int step = numSteps-1; step >= 0; step--) {
                if (stepType[step] == CustomIntegrator::ConstrainPositions)
                    beforeConstrain = true;
peastman's avatar
peastman committed
8214
                else if (stepType[step] == CustomIntegrator::ComputePerDof && variable[step] == "x" && beforeConstrain) {
peastman's avatar
peastman committed
8215
                    storePosAsDelta[step] = true;
peastman's avatar
peastman committed
8216
8217
                    beforeConstrain = false;
                }
peastman's avatar
peastman committed
8218
8219
8220
8221
8222
8223
8224
8225
8226
            }
            bool storedAsDelta = false;
            for (int step = 0; step < numSteps; step++) {
                loadPosAsDelta[step] = storedAsDelta;
                if (storePosAsDelta[step] == true)
                    storedAsDelta = true;
                if (stepType[step] == CustomIntegrator::ConstrainPositions)
                    storedAsDelta = false;
            }
8227
8228
        }
        
8229
8230
8231
        // Identify steps that can be merged into a single kernel.
        
        for (int step = 1; step < numSteps; step++) {
8232
            if (invalidatesForces[step-1] || forceGroupFlags[step] != forceGroupFlags[step-1])
8233
                continue;
8234
            if (stepType[step-1] == CustomIntegrator::ComputePerDof && stepType[step] == CustomIntegrator::ComputePerDof)
8235
8236
                merged[step] = true;
        }
8237
8238
8239
8240
8241
        for (int step = numSteps-1; step > 0; step--)
            if (merged[step]) {
                needsForces[step-1] = (needsForces[step] || needsForces[step-1]);
                needsEnergy[step-1] = (needsEnergy[step] || needsEnergy[step-1]);
                needsGlobals[step-1] = (needsGlobals[step] || needsGlobals[step-1]);
Peter Eastman's avatar
Peter Eastman committed
8242
                computeBothForceAndEnergy[step-1] = (computeBothForceAndEnergy[step] || computeBothForceAndEnergy[step-1]);
8243
            }
8244
        
8245
8246
8247
        // Loop over all steps and create the kernels for them.
        
        for (int step = 0; step < numSteps; step++) {
8248
            if ((stepType[step] == CustomIntegrator::ComputePerDof || stepType[step] == CustomIntegrator::ComputeSum) && !merged[step]) {
8249
8250
8251
                // Compute a per-DOF value.
                
                stringstream compute;
8252
                for (int i = 0; i < perDofValues.size(); i++)
8253
                    compute << tempType<<" perDof"<<cl.intToString(i)<<" = convert_"<<tempType<<"(perDofValues"<<cl.intToString(i)<<"[index].xyz);\n";
8254
                int numGaussian = 0, numUniform = 0;
8255
                for (int j = step; j < numSteps && (j == step || merged[j]); j++) {
8256
8257
                    numGaussian += numAtoms*usesVariable(expression[j][0], "gaussian");
                    numUniform += numAtoms*usesVariable(expression[j][0], "uniform");
8258
                    compute << "{\n";
8259
                    if (numGaussian > 0)
8260
                        compute << "float4 gaussian = gaussianValues[gaussianIndex+index];\n";
8261
                    if (numUniform > 0)
8262
                        compute << "float4 uniform = uniformValues[uniformIndex+index];\n";
8263
                    compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j][0], integrator, forceName[j], energyName[j], functionList, functionNames);
8264
8265
8266
                    if (variable[j] == "x") {
                        if (storePosAsDelta[j]) {
                            if (cl.getSupportsDoublePrecision())
8267
                                compute << "posDelta[index] = convert_mixed4(convert_double4(position)-convert_double4(loadPos(posq, posqCorrection, index)));\n";
8268
8269
8270
                            else
                                compute << "posDelta[index] = position-posq[index];\n";
                        }
8271
                        else
8272
                            compute << "storePos(posq, posqCorrection, index, position);\n";
8273
                    }
8274
                    else if (variable[j] == "v")
8275
                        compute << "velm[index] = convert_mixed4(velocity);\n";
8276
                    else {
8277
                        for (int i = 0; i < perDofValues.size(); i++)
8278
                            compute << "perDofValues"<<cl.intToString(i)<<"[index] = ("<<perDofType<<") (perDof"<<cl.intToString(i)<<".x, perDof"<<cl.intToString(i)<<".y, perDof"<<cl.intToString(i)<<".z, 0);\n";
8279
                    }
8280
                    if (numGaussian > 0)
8281
                        compute << "gaussianIndex += NUM_ATOMS;\n";
8282
                    if (numUniform > 0)
8283
                        compute << "uniformIndex += NUM_ATOMS;\n";
8284
                    compute << "}\n";
8285
8286
8287
8288
                }
                map<string, string> replacements;
                replacements["COMPUTE_STEP"] = compute.str();
                stringstream args;
8289
8290
8291
                for (int i = 0; i < perDofValues.size(); i++) {
                    string valueName = "perDofValues"+cl.intToString(i);
                    args << ", __global " << perDofType << "* restrict " << valueName;
8292
                }
8293
8294
                for (int i = 0; i < (int) tableTypes.size(); i++)
                    args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
8295
                replacements["PARAMETER_ARGUMENTS"] = args.str();
8296
8297
8298
8299
                if (loadPosAsDelta[step])
                    defines["LOAD_POS_AS_DELTA"] = "1";
                else if (defines.find("LOAD_POS_AS_DELTA") != defines.end())
                    defines.erase("LOAD_POS_AS_DELTA");
8300
8301
8302
                cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customIntegratorPerDof, replacements), defines);
                cl::Kernel kernel = cl::Kernel(program, "computePerDof");
                kernels[step].push_back(kernel);
8303
8304
                requiredGaussian[step] = numGaussian;
                requiredUniform[step] = numUniform;
8305
8306
                int index = 0;
                kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
8307
                setPosqCorrectionArg(cl, kernel, index++);
8308
8309
8310
8311
                kernel.setArg<cl::Buffer>(index++, integration.getPosDelta().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, cl.getVelm().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, cl.getForce().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, integration.getStepSize().getDeviceBuffer());
peastman's avatar
peastman committed
8312
8313
                kernel.setArg<cl::Buffer>(index++, globalValues.getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, sumBuffer.getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
8314
                index += 4;
peastman's avatar
peastman committed
8315
                kernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs.getDeviceBuffer());
8316
8317
                for (auto& array : perDofValues)
                    kernel.setArg<cl::Memory>(index++, array.getDeviceBuffer());
peastman's avatar
peastman committed
8318
8319
                for (auto& array : tabulatedFunctions)
                    kernel.setArg<cl::Buffer>(index++, array.getDeviceBuffer());
8320
                if (stepType[step] == CustomIntegrator::ComputeSum) {
8321
8322
                    // Create a second kernel for this step that sums the values.

8323
                    program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
8324
                    kernel = cl::Kernel(program, useDouble ? "computeDoubleSum" : "computeFloatSum");
8325
8326
                    kernels[step].push_back(kernel);
                    index = 0;
peastman's avatar
peastman committed
8327
8328
                    kernel.setArg<cl::Buffer>(index++, sumBuffer.getDeviceBuffer());
                    kernel.setArg<cl::Buffer>(index++, summedValue.getDeviceBuffer());
peastman's avatar
peastman committed
8329
                    kernel.setArg<cl_int>(index++, numAtoms);
8330
                }
8331
            }
8332
8333
8334
8335
8336
8337
8338
8339
            else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
                // Apply position constraints.

                cl::Program program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
                cl::Kernel kernel = cl::Kernel(program, "applyPositionDeltas");
                kernels[step].push_back(kernel);
                int index = 0;
                kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
8340
                setPosqCorrectionArg(cl, kernel, index++);
8341
8342
                kernel.setArg<cl::Buffer>(index++, integration.getPosDelta().getDeviceBuffer());
            }
8343
        }
8344
        
8345
8346
8347
        // Initialize the random number generator.
        
        int maxUniformRandoms = 1;
peastman's avatar
peastman committed
8348
8349
        for (int required : requiredUniform)
            maxUniformRandoms = max(maxUniformRandoms, required);
peastman's avatar
peastman committed
8350
8351
8352
        uniformRandoms.initialize<mm_float4>(cl, maxUniformRandoms, "uniformRandoms");
        randomSeed.initialize<mm_int4>(cl, cl.getNumThreadBlocks()*OpenCLContext::ThreadBlockSize, "randomSeed");
        vector<mm_int4> seed(randomSeed.getSize());
8353
        int rseed = integrator.getRandomNumberSeed();
8354
        // A random seed of 0 means use a unique one
8355
8356
8357
        if (rseed == 0)
            rseed = osrngseed();
        unsigned int r = (unsigned int) (rseed+1);
peastman's avatar
peastman committed
8358
8359
8360
8361
8362
        for (auto& s : seed) {
            s.x = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
            s.y = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
            s.z = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
            s.w = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
8363
        }
peastman's avatar
peastman committed
8364
        randomSeed.upload(seed);
8365
8366
        cl::Program randomProgram = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
        randomKernel = cl::Kernel(randomProgram, "generateRandomNumbers");
8367
        randomKernel.setArg<cl_int>(0, maxUniformRandoms);
peastman's avatar
peastman committed
8368
8369
        randomKernel.setArg<cl::Buffer>(1, uniformRandoms.getDeviceBuffer());
        randomKernel.setArg<cl::Buffer>(2, randomSeed.getDeviceBuffer());
8370
        
8371
8372
8373
        // Create the kernel for computing kinetic energy.

        stringstream computeKE;
8374
        for (int i = 0; i < perDofValues.size(); i++)
8375
            computeKE << tempType<<" perDof"<<cl.intToString(i)<<" = convert_"<<tempType<<"(perDofValues"<<cl.intToString(i)<<"[index].xyz);\n";
8376
        Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize();
8377
        computeKE << createPerDofComputation("", keExpression, integrator, "f", "", functionList, functionNames);
8378
8379
8380
        map<string, string> replacements;
        replacements["COMPUTE_STEP"] = computeKE.str();
        stringstream args;
8381
8382
8383
        for (int i = 0; i < perDofValues.size(); i++) {
            string valueName = "perDofValues"+cl.intToString(i);
            args << ", __global " << perDofType << "* restrict " << valueName;
8384
        }
8385
8386
        for (int i = 0; i < (int) tableTypes.size(); i++)
            args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
8387
8388
8389
        replacements["PARAMETER_ARGUMENTS"] = args.str();
        if (defines.find("LOAD_POS_AS_DELTA") != defines.end())
            defines.erase("LOAD_POS_AS_DELTA");
Peter Eastman's avatar
Peter Eastman committed
8390
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customIntegratorPerDof, replacements), defines);
8391
        kineticEnergyKernel = cl::Kernel(program, "computePerDof");
Peter Eastman's avatar
Peter Eastman committed
8392
        int index = 0;
8393
8394
8395
8396
8397
8398
        kineticEnergyKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kineticEnergyKernel, index++);
        kineticEnergyKernel.setArg<cl::Buffer>(index++, integration.getPosDelta().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, cl.getVelm().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, cl.getForce().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, integration.getStepSize().getDeviceBuffer());
peastman's avatar
peastman committed
8399
8400
        kineticEnergyKernel.setArg<cl::Buffer>(index++, globalValues.getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer.getDeviceBuffer());
8401
        index += 2;
peastman's avatar
peastman committed
8402
        kineticEnergyKernel.setArg<cl::Buffer>(index++, uniformRandoms.getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
8403
        if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision())
Peter Eastman's avatar
Peter Eastman committed
8404
8405
8406
            kineticEnergyKernel.setArg<cl_double>(index++, 0.0);
        else
            kineticEnergyKernel.setArg<cl_float>(index++, 0.0f);
peastman's avatar
peastman committed
8407
        kineticEnergyKernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs.getDeviceBuffer());
8408
        for (auto& array : perDofValues)
8409
            kineticEnergyKernel.setArg<cl::Buffer>(index++, array.getDeviceBuffer());
peastman's avatar
peastman committed
8410
8411
        for (auto& array : tabulatedFunctions)
            kineticEnergyKernel.setArg<cl::Buffer>(index++, array.getDeviceBuffer());
8412
8413
8414
8415
8416
8417
8418
        keNeedsForce = usesVariable(keExpression, "f");

        // Create a second kernel to sum the values.

        program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
        sumKineticEnergyKernel = cl::Kernel(program, useDouble ? "computeDoubleSum" : "computeFloatSum");
        index = 0;
peastman's avatar
peastman committed
8419
8420
        sumKineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer.getDeviceBuffer());
        sumKineticEnergyKernel.setArg<cl::Buffer>(index++, summedValue.getDeviceBuffer());
peastman's avatar
peastman committed
8421
        sumKineticEnergyKernel.setArg<cl_int>(index++, numAtoms);
8422
8423
8424
8425
8426

        // Delete the custom functions.

        for (auto& function : functions)
            delete function.second;
8427
    }
8428

8429
    // Make sure all values (variables, parameters, etc.) are up to date.
8430
    
8431
8432
8433
8434
8435
8436
8437
8438
8439
    for (int i = 0; i < perDofValues.size(); i++) {
        if (!deviceValuesAreCurrent[i]) {
            if (useDouble)
                perDofValues[i].upload(localPerDofValuesDouble[i]);
            else
                perDofValues[i].upload(localPerDofValuesFloat[i]);
            deviceValuesAreCurrent[i] = true;
        }
        localValuesAreCurrent[i] = false;
8440
8441
    }
    double stepSize = integrator.getStepSize();
8442
    recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator);
8443
8444
    for (int i = 0; i < (int) parameterNames.size(); i++) {
        double value = context.getParameter(parameterNames[i]);
peastman's avatar
peastman committed
8445
8446
        if (value != localGlobalValues[parameterVariableIndex[i]]) {
            localGlobalValues[parameterVariableIndex[i]] = value;
8447
            deviceGlobalsAreCurrent = false;
8448
8449
        }
    }
8450
}
8451

8452
8453
8454
8455
8456
8457
8458
8459
8460
8461
8462
8463
8464
8465
/**
 * Build a copy of an expression tree in which every call to deriv() is replaced by a
 * custom function that reads the value from energyParamDerivs.
 *
 * @param node     the root of the (sub)tree to process
 * @param context  the context whose parameters are the valid second arguments to deriv()
 * @return the rebuilt tree
 * @throws OpenMMException if the second argument to deriv() is not a context parameter
 */
ExpressionTreeNode OpenCLIntegrateCustomStepKernel::replaceDerivFunctions(const ExpressionTreeNode& node, ContextImpl& context) {
    // This is called recursively to identify calls to the deriv() function inside global expressions,
    // and replace them with a custom function that returns the correct value.

    const Operation& op = node.getOperation();
    if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
        // The second child names the parameter to differentiate with respect to.
        string param = node.getChildren()[1].getOperation().getName();
        if (context.getParameters().find(param) == context.getParameters().end())
            throw OpenMMException("The second argument to deriv() must be a context parameter");
        needsEnergyParamDerivs = true;
        return ExpressionTreeNode(new Operation::Custom("deriv", new DerivFunction(energyParamDerivs, param)));
    }
    else {
        // Any other node is rebuilt from a clone of its operation and its processed children.
        vector<ExpressionTreeNode> children;
        for (auto& child : node.getChildren())
            children.push_back(replaceDerivFunctions(child, context));
        return ExpressionTreeNode(op.clone(), children);
    }
}

void OpenCLIntegrateCustomStepKernel::findExpressionsForDerivs(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& variableNodes) {
    // This is called recursively to identify calls to the deriv() function inside per-DOF expressions,
    // and record the code to replace them with.
    
    const Operation& op = node.getOperation();
    if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
        string param = node.getChildren()[1].getOperation().getName();
        int index;
        for (index = 0; index < perDofEnergyParamDerivNames.size() && param != perDofEnergyParamDerivNames[index]; index++)
            ;
        if (index == perDofEnergyParamDerivNames.size())
            perDofEnergyParamDerivNames.push_back(param);
        variableNodes.push_back(make_pair(node, "energyParamDerivs["+cl.intToString(index)+"]"));
        needsEnergyParamDerivs = true;
    }
    else {
peastman's avatar
peastman committed
8488
8489
        for (auto& child : node.getChildren())
            findExpressionsForDerivs(child, variableNodes);
8490
8491
8492
    }
}

8493
8494
8495
8496
8497
void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    // Run every computation step of the custom integrator once, then advance the time and
    // step count.
    //
    // context         the context the integrator is bound to
    // integrator      the CustomIntegrator whose computations are executed
    // forcesAreValid  in/out flag: on entry, whether the forces currently stored are valid;
    //                 updated as steps recompute or invalidate them

    prepareForComputation(context, integrator, forcesAreValid);
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    int numSteps = integrator.getNumComputations();
    if (!forcesAreValid)
        savedEnergy.clear();

    // The per-DOF and sum steps bind the same random number and energy arguments to their
    // first kernel, then refresh the uniform random numbers if the step uses any.

    auto prepareStepKernel = [&](int stepIndex) {
        auto& kernel = kernels[stepIndex][0];
        kernel.setArg<cl_uint>(9, integration.prepareRandomNumbers(requiredGaussian[stepIndex]));
        kernel.setArg<cl::Buffer>(8, integration.getRandom().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(10, uniformRandoms.getDeviceBuffer());
        if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision())
            kernel.setArg<cl_double>(11, energy);
        else
            kernel.setArg<cl_float>(11, (cl_float) energy);
        if (requiredUniform[stepIndex] > 0)
            cl.executeKernel(randomKernel, numAtoms);
    };

    // Loop over computation steps in the integrator and execute them.

    for (int step = 0; step < numSteps; ) {
        int nextStep = step+1;
        int forceGroups = forceGroupFlags[step];
        int lastForceGroups = context.getLastForceGroups();
        bool haveForces = (!needsForces[step] || (forcesAreValid && lastForceGroups == forceGroups));
        bool haveEnergy = (!needsEnergy[step] || savedEnergy.find(forceGroups) != savedEnergy.end());
        if (!haveForces || !haveEnergy) {
            if (forcesAreValid) {
                if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
                    // The forces are still valid.  We just need a different force group right now.  Save the old
                    // forces in case we need them again.

                    cl.getForce().copyTo(savedForces[lastForceGroups]);
                    validSavedForces.insert(lastForceGroups);
                }
            }
            else
                validSavedForces.clear();

            // Recompute forces and/or energy.  Figure out what is actually needed
            // between now and the next time they get invalidated again.

            bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]);
            bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]);
            if (!computeEnergy && validSavedForces.find(forceGroups) != validSavedForces.end()) {
                // We can just restore the forces we saved earlier.

                savedForces[forceGroups].copyTo(cl.getForce());
                context.getLastForceGroups() = forceGroups;
            }
            else {
                recordChangedParameters(context);
                energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
                savedEnergy[forceGroups] = energy;
                if (needsEnergyParamDerivs) {
                    context.getEnergyParameterDerivatives(energyParamDerivs);
                    if (!perDofEnergyParamDerivNames.empty()) {
                        // Copy the derivatives used by per-DOF expressions to the device.

                        for (int i = 0; i < (int) perDofEnergyParamDerivNames.size(); i++)
                            localPerDofEnergyParamDerivs[i] = energyParamDerivs[perDofEnergyParamDerivNames[i]];
                        perDofEnergyParamDerivs.upload(localPerDofEnergyParamDerivs, true, true);
                    }
                }
                forcesAreValid = true;
            }
        }
        if (needsEnergy[step])
            energy = savedEnergy[forceGroups];
        if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
            // Upload the global values to the device.

            globalValues.upload(localGlobalValues, true, true);
            deviceGlobalsAreCurrent = true;
        }
        bool stepInvalidatesForces = invalidatesForces[step];
        if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
            prepareStepKernel(step);
            cl.executeKernel(kernels[step][0], numAtoms, 128);
        }
        else if (stepType[step] == CustomIntegrator::ComputeGlobal) {
            expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
            expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
            expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
            recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step], integrator);
        }
        else if (stepType[step] == CustomIntegrator::ComputeSum) {
            prepareStepKernel(step);
            cl.clearBuffer(sumBuffer);
            cl.executeKernel(kernels[step][0], numAtoms, 128);
            cl.executeKernel(kernels[step][1], sumWorkGroupSize, sumWorkGroupSize);

            // Read back the sum in the precision it was accumulated in.

            if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
                double value;
                summedValue.download(&value);
                recordGlobalValue(value, stepTarget[step], integrator);
            }
            else {
                float value;
                summedValue.download(&value);
                recordGlobalValue(value, stepTarget[step], integrator);
            }
        }
        else if (stepType[step] == CustomIntegrator::UpdateContextState) {
            recordChangedParameters(context);
            stepInvalidatesForces = context.updateContextState();
        }
        else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
            if (hasAnyConstraints) {
                integration.applyConstraints(integrator.getConstraintTolerance());
                cl.executeKernel(kernels[step][0], numAtoms);
            }
            integration.computeVirtualSites();
        }
        else if (stepType[step] == CustomIntegrator::ConstrainVelocities) {
            integration.applyVelocityConstraints(integrator.getConstraintTolerance());
        }
        else if (stepType[step] == CustomIntegrator::IfBlockStart) {
            if (!evaluateCondition(step))
                nextStep = blockEnd[step]+1;
        }
        else if (stepType[step] == CustomIntegrator::WhileBlockStart) {
            if (!evaluateCondition(step))
                nextStep = blockEnd[step]+1;
        }
        else if (stepType[step] == CustomIntegrator::BlockEnd) {
            if (blockEnd[step] != -1)
                nextStep = blockEnd[step]; // Return to the start of a while block.
        }
        if (stepInvalidatesForces) {
            forcesAreValid = false;
            savedEnergy.clear();
        }
        step = nextStep;
    }
    recordChangedParameters(context);

    // Update the time and step count.

    cl.setTime(cl.getTime()+integrator.getStepSize());
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    if (cl.getAtomsWereReordered()) {
        // Forces (including any saved copies) no longer match the new atom order.

        forcesAreValid = false;
        validSavedForces.clear();
    }

    // Reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif
}

8650
8651
8652
8653
8654
8655
8656
8657
8658
8659
8660
8661
8662
8663
8664
8665
8666
8667
8668
8669
8670
8671
8672
bool OpenCLIntegrateCustomStepKernel::evaluateCondition(int step) {
    expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
    expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
    expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
    double lhs = globalExpressions[step][0].evaluate();
    double rhs = globalExpressions[step][1].evaluate();
    switch (comparisons[step]) {
        case CustomIntegratorUtilities::EQUAL:
            return (lhs == rhs);
        case CustomIntegratorUtilities::LESS_THAN:
            return (lhs < rhs);
        case CustomIntegratorUtilities::GREATER_THAN:
            return (lhs > rhs);
        case CustomIntegratorUtilities::NOT_EQUAL:
            return (lhs != rhs);
        case CustomIntegratorUtilities::LESS_THAN_OR_EQUAL:
            return (lhs <= rhs);
        case CustomIntegratorUtilities::GREATER_THAN_OR_EQUAL:
            return (lhs >= rhs);
    }
    throw OpenMMException("Invalid comparison operator");
}

8673
8674
8675
8676
8677
8678
8679
8680
8681
double OpenCLIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    // Run the kinetic energy kernels and return the summed result.  If the kinetic energy
    // expression depends on the force and the forces are stale, they are recomputed first
    // and forcesAreValid is set to true.

    prepareForComputation(context, integrator, forcesAreValid);
    if (keNeedsForce && !forcesAreValid) {
        // Since the forces will be marked valid afterward, the potential energy is computed
        // too whenever any step of the integrator expects it.

        bool anyStepNeedsEnergy = false;
        const int numComputations = integrator.getNumComputations();
        for (int stepIndex = 0; stepIndex < numComputations; stepIndex++) {
            if (needsEnergy[stepIndex])
                anyStepNeedsEnergy = true;
        }
        energy = context.calcForcesAndEnergy(true, anyStepNeedsEnergy, -1);
        forcesAreValid = true;
    }
    cl.clearBuffer(sumBuffer);
    kineticEnergyKernel.setArg<cl::Buffer>(8, cl.getIntegrationUtilities().getRandom().getDeviceBuffer());
    kineticEnergyKernel.setArg<cl_uint>(9, 0);
    cl.executeKernel(kineticEnergyKernel, cl.getNumAtoms());
    cl.executeKernel(sumKineticEnergyKernel, sumWorkGroupSize, sumWorkGroupSize);

    // Download the sum using the element type that matches the precision mode.

    const bool useDouble = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision());
    if (!useDouble) {
        float singleResult;
        summedValue.download(&singleResult);
        return singleResult;
    }
    double doubleResult;
    summedValue.download(&doubleResult);
    return doubleResult;
}

8702
void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
8703
8704
    switch (target.type) {
        case DT:
peastman's avatar
peastman committed
8705
            if (value != localGlobalValues[dtVariableIndex])
8706
                deviceGlobalsAreCurrent = false;
8707
            expressionSet.setVariable(dtVariableIndex, value);
peastman's avatar
peastman committed
8708
            localGlobalValues[dtVariableIndex] = value;
8709
            cl.getIntegrationUtilities().setNextStepSize(value);
8710
            integrator.setStepSize(value);
8711
8712
8713
8714
            break;
        case VARIABLE:
        case PARAMETER:
            expressionSet.setVariable(target.variableIndex, value);
peastman's avatar
peastman committed
8715
            localGlobalValues[target.variableIndex] = value;
8716
8717
8718
8719
8720
            deviceGlobalsAreCurrent = false;
            break;
    }
}

8721
8722
8723
void OpenCLIntegrateCustomStepKernel::recordChangedParameters(ContextImpl& context) {
    // Push the integrator's values for context parameters back to the context wherever the
    // two disagree.  Nothing to do if the integrator never modifies parameters.

    if (modifiesParameters) {
        const int numParameters = (int) parameterNames.size();
        for (int p = 0; p < numParameters; p++) {
            const auto& name = parameterNames[p];
            const double contextValue = context.getParameter(name);
            if (contextValue != localGlobalValues[parameterVariableIndex[p]])
                context.setParameter(name, localGlobalValues[parameterVariableIndex[p]]);
        }
    }
}

8731
void OpenCLIntegrateCustomStepKernel::getGlobalVariables(ContextImpl& context, vector<double>& values) const {
    // Copy the current values of all global variables into values.

    if (globalValues.isInitialized()) {
        values.resize(numGlobalVariables);
        for (int variable = 0; variable < numGlobalVariables; variable++) {
            const int slot = globalVariableIndex[variable];
            values[variable] = localGlobalValues[slot];
        }
    }
    else {
        // The data structures haven't been created yet, so hand back the values that were given earlier.

        values = initialGlobalVariables;
    }
}

void OpenCLIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, const vector<double>& values) {
    // Replace the values of all global variables with the entries of values.

    if (numGlobalVariables != 0) {
        if (globalValues.isInitialized()) {
            for (int variable = 0; variable < numGlobalVariables; variable++) {
                const int slot = globalVariableIndex[variable];
                localGlobalValues[slot] = values[variable];
                expressionSet.setVariable(slot, values[variable]);
            }

            // The device copy no longer matches the host copy.

            deviceGlobalsAreCurrent = false;
        }
        else {
            // The data structures haven't been created yet, so just remember the list of values.

            initialGlobalVariables = values;
        }
    }
}

void OpenCLIntegrateCustomStepKernel::getPerDofVariable(ContextImpl& context, int variable, vector<Vec3>& values) const {
    // Return the values of one per-DOF variable.  The host-side copy is refreshed from the
    // device if it is stale, and entries are written through the atom index mapping.

    values.resize(perDofValues[variable].getSize());
    const vector<int>& order = cl.getAtomIndex();
    const int numValues = (int) values.size();
    const bool useDouble = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision());
    if (useDouble) {
        if (!localValuesAreCurrent[variable]) {
            perDofValues[variable].download(localPerDofValuesDouble[variable]);
            localValuesAreCurrent[variable] = true;
        }
        for (int i = 0; i < numValues; i++) {
            const auto& element = localPerDofValuesDouble[variable][i];
            Vec3& target = values[order[i]];
            target[0] = element.x;
            target[1] = element.y;
            target[2] = element.z;
        }
    }
    else {
        if (!localValuesAreCurrent[variable]) {
            perDofValues[variable].download(localPerDofValuesFloat[variable]);
            localValuesAreCurrent[variable] = true;
        }
        for (int i = 0; i < numValues; i++) {
            const auto& element = localPerDofValuesFloat[variable][i];
            Vec3& target = values[order[i]];
            target[0] = element.x;
            target[1] = element.y;
            target[2] = element.z;
        }
    }
}

void OpenCLIntegrateCustomStepKernel::setPerDofVariable(ContextImpl& context, int variable, const vector<Vec3>& values) {
    // Store new values for one per-DOF variable in the host-side copy, reading them through
    // the atom index mapping.  The device copy is flagged stale rather than uploaded here.

    const vector<int>& order = cl.getAtomIndex();
    localValuesAreCurrent[variable] = true;
    deviceValuesAreCurrent[variable] = false;
    const int numValues = (int) values.size();
    const bool useDouble = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision());
    if (useDouble) {
        auto& local = localPerDofValuesDouble[variable];
        local.resize(values.size());
        for (int i = 0; i < numValues; i++) {
            const Vec3& source = values[order[i]];
            local[i] = mm_double4(source[0], source[1], source[2], 0);
        }
    }
    else {
        auto& local = localPerDofValuesFloat[variable];
        local.resize(values.size());
        for (int i = 0; i < numValues; i++) {
            const Vec3& source = values[order[i]];
            local[i] = mm_float4(source[0], source[1], source[2], 0);
        }
    }
}

8802
8803
8804
void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const AndersenThermostat& thermostat) {
    randomSeed = thermostat.getRandomNumberSeed();
    map<string, string> defines;
8805
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
8806
    cl::Program program = cl.createProgram(OpenCLKernelSources::andersenThermostat, defines);
8807
    kernel = cl::Kernel(program, "applyAndersenThermostat");
Peter Eastman's avatar
Peter Eastman committed
8808
    cl.getIntegrationUtilities().initRandomNumberGenerator(randomSeed);
8809
8810
8811
8812

    // Create the arrays with the group definitions.

    vector<vector<int> > groups = AndersenThermostatImpl::calcParticleGroups(system);
peastman's avatar
peastman committed
8813
8814
    atomGroups.initialize<int>(cl, cl.getNumAtoms(), "atomGroups");
    vector<int> atoms(atomGroups.getSize());
8815
8816
8817
8818
    for (int i = 0; i < (int) groups.size(); i++) {
        for (int j = 0; j < (int) groups[i].size(); j++)
            atoms[groups[i][j]] = i;
    }
peastman's avatar
peastman committed
8819
    atomGroups.upload(atoms);
8820
8821
8822
8823
8824
8825
8826
8827
}

void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) {
    // Apply the thermostat.  Buffer arguments are bound on the first call only; the
    // collision frequency, kT and random number offset are refreshed on every call.

    auto& integration = cl.getIntegrationUtilities();
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        kernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(3, integration.getStepSize().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(4, integration.getRandom().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(6, atomGroups.getDeviceBuffer());
    }
    const cl_float collisionFrequency = (cl_float) context.getParameter(AndersenThermostat::CollisionFrequency());
    kernel.setArg<cl_float>(0, collisionFrequency);
    const cl_float kT = (cl_float) (BOLTZ*context.getParameter(AndersenThermostat::Temperature()));
    kernel.setArg<cl_float>(1, kT);
    kernel.setArg<cl_uint>(5, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel, cl.getNumAtoms());
}
8835
8836
8837
8838
8839
8840
8841
8842
8843
8844
8845
8846
8847
8848
8849
8850
8851
8852
8853
8854
8855
8856
8857
void OpenCLNoseHooverChainKernel::initialize() {

    bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    map<string, string> defines;
    defines["BEGIN_YS_LOOP"] = "const real arr[1] = {1.0}; for(int i=0;i<1;++i) { const real ys = arr[i];";
    defines["END_YS_LOOP"] = "}";
    cl::Program program = cl.createProgram(OpenCLKernelSources::noseHooverChain, defines);
    propagateKernels[1] = cl::Kernel(program, "propagateNoseHooverChain");
    defines["BEGIN_YS_LOOP"] = "const real arr[3] = {0.828981543588751, -0.657963087177502, 0.828981543588751}; for(int i=0;i<3;++i) { const real ys = arr[i];";
    program = cl.createProgram(OpenCLKernelSources::noseHooverChain, defines);
    propagateKernels[3] = cl::Kernel(program, "propagateNoseHooverChain");
    defines["BEGIN_YS_LOOP"] = "const real arr[5] = {0.2967324292201065, 0.2967324292201065, -0.186929716880426, 0.2967324292201065, 0.2967324292201065}; for(int i=0;i<5;++i) { const real ys = arr[i];";
    program = cl.createProgram(OpenCLKernelSources::noseHooverChain, defines);
    propagateKernels[5] = cl::Kernel(program, "propagateNoseHooverChain");
    program = cl.createProgram(OpenCLKernelSources::noseHooverChain, defines);
    reduceEnergyKernel = cl::Kernel(program, "reduceEnergyPair");

    computeHeatBathEnergyKernel = cl::Kernel(program, "computeHeatBathEnergy");
    computeAtomsKineticEnergyKernel = cl::Kernel(program, "computeAtomsKineticEnergy");
    computePairsKineticEnergyKernel = cl::Kernel(program, "computePairsKineticEnergy");
    scaleAtomsVelocitiesKernel = cl::Kernel(program, "scaleAtomsVelocities");
    scalePairsVelocitiesKernel = cl::Kernel(program, "scalePairsVelocities");
8858
8859
8860
8861
8862
8863
    int energyBufferSize = cl.getEnergyBuffer().getSize();
    if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
        energyBuffer.initialize<mm_double2>(cl, energyBufferSize, "energyBuffer");
    } else {
        energyBuffer.initialize<mm_float2>(cl, energyBufferSize, "energyBuffer");
    }
8864
8865
8866
8867
8868
8869
}

std::pair<double, double> OpenCLNoseHooverChainKernel::propagateChain(ContextImpl& context, const NoseHooverChain &nhc, std::pair<double, double> kineticEnergies, double timeStep) {
    // Propagate a Nose-Hoover chain on the device.
    //
    // context          the context (not used by this implementation)
    // nhc              the chain to propagate
    // kineticEnergies  ignored; the kernel reads the kinetic energy from kineticEnergyBuffer
    // timeStep         the time step, passed to the kernel in single precision
    //
    // Always returns {0, 0}.  Throws OpenMMException if the chain's number of
    // Yoshida-Suzuki time steps is not 1, 3 or 5.

    const bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    const int chainID = nhc.getChainID();
    const int nAtoms = nhc.getThermostatedAtoms().size();
    const int nPairs = nhc.getThermostatedPairs().size();
    const int chainLength = nhc.getChainLength();
    const int numYS = nhc.getNumYoshidaSuzukiTimeSteps();
    const int numMTS = nhc.getNumMultiTimeSteps();
    const int numDOFs = nhc.getNumDegreesOfFreedom();
    const double temperature = nhc.getTemperature();
    const double frequency = nhc.getCollisionFrequency();
    const double relativeTemperature = nhc.getRelativeTemperature();
    const double relativeFrequency = nhc.getRelativeCollisionFrequency();

    if (numYS != 1 && numYS != 3 && numYS != 5) {
        throw OpenMMException("Number of Yoshida Suzuki time steps has to be 1, 3, or 5.");
    }

    auto & chainState = cl.getIntegrationUtilities().getNoseHooverChainState();

    // Make sure the single-element scale factor buffer exists and starts out zeroed.

    if (!scaleFactorBuffer.isInitialized() || scaleFactorBuffer.getSize() == 0) {
        if (useDouble) {
            std::vector<mm_double2> zeros{{0,0}};
            if (scaleFactorBuffer.isInitialized())
                scaleFactorBuffer.resize(1);
            else
                scaleFactorBuffer.initialize<mm_double2>(cl, 1, "scaleFactorBuffer");
            scaleFactorBuffer.upload(zeros);
        } else {
            std::vector<mm_float2> zeros{{0,0}};
            if (scaleFactorBuffer.isInitialized())
                scaleFactorBuffer.resize(1);
            else
                scaleFactorBuffer.initialize<mm_float2>(cl, 1, "scaleFactorBuffer");
            scaleFactorBuffer.upload(zeros);
        }
    }

    // Make sure the chain mass and force work arrays exist.  Each array is checked on its
    // own, so a state where only one of the two is initialized is handled correctly.

    if (!chainForces.isInitialized() || !chainMasses.isInitialized()) {
        auto createZeroed = [&](OpenCLArray& array, const std::string& name) {
            if (useDouble) {
                std::vector<cl_double> zeros(chainLength, 0);
                if (array.isInitialized())
                    array.resize(chainLength);
                else
                    array.initialize<cl_double>(cl, chainLength, name);
                array.upload(zeros);
            } else {
                std::vector<cl_float> zeros(chainLength, 0);
                if (array.isInitialized())
                    array.resize(chainLength);
                else
                    array.initialize<cl_float>(cl, chainLength, name);
                array.upload(zeros);
            }
        };
        createZeroed(chainMasses, "chainMasses");
        createZeroed(chainForces, "chainForces");
    }

    // Grow either array if this chain is longer than the ones seen so far.  (Previously the
    // chainForces test resized chainMasses, so chainForces was never enlarged.)

    if (chainForces.getSize() < chainLength) chainForces.resize(chainLength);
    if (chainMasses.getSize() < chainLength) chainMasses.resize(chainLength);

    const float timeStepFloat = (float) timeStep;

    // Returns the state array stored under the given key, creating it and filling it with
    // zeros if it is missing, uninitialized, or has the wrong length.

    auto prepareChainState = [&](int key) -> OpenCLArray& {
        if (!chainState.count(key))
            chainState[key] = OpenCLArray();
        OpenCLArray& state = chainState.at(key);
        if (!state.isInitialized() || state.getSize() != chainLength) {
            // We need to upload the OpenCL array
            if (useDouble) {
                if (state.isInitialized())
                    state.resize(chainLength);
                else
                    state.initialize<mm_double2>(cl, chainLength, "chainState" + std::to_string(key));
                std::vector<mm_double2> zeros(chainLength, mm_double2(0.0, 0.0));
                state.upload(zeros.data());
            } else {
                if (state.isInitialized())
                    state.resize(chainLength);
                else
                    state.initialize<mm_float2>(cl, chainLength, "chainState" + std::to_string(key));
                std::vector<mm_float2> zeros(chainLength, mm_float2(0.0f, 0.0f));
                state.upload(zeros.data());
            }
        }
        return state;
    };

    // Binds all arguments of the propagation kernel selected by numYS and launches it as
    // a single work item.

    auto launchPropagation = [&](OpenCLArray& state, int chainType, int ndf, double kT, double chainFrequency) {
        auto& kernel = propagateKernels[numYS];
        kernel.setArg<cl::Buffer>(0, state.getDeviceBuffer());
        kernel.setArg<cl::Buffer>(1, kineticEnergyBuffer.getDeviceBuffer());
        kernel.setArg<cl::Buffer>(2, scaleFactorBuffer.getDeviceBuffer());
        kernel.setArg<cl::Buffer>(3, chainMasses.getDeviceBuffer());
        kernel.setArg<cl::Buffer>(4, chainForces.getDeviceBuffer());
        kernel.setArg<cl_int>(5, chainType);
        kernel.setArg<cl_int>(6, chainLength);
        kernel.setArg<cl_int>(7, numMTS);
        kernel.setArg<cl_int>(8, ndf);
        kernel.setArg<cl_float>(9, timeStepFloat);
        if (useDouble)
            kernel.setArg<cl_double>(10, kT);
        else
            kernel.setArg<cl_float>(10, (float) kT);
        kernel.setArg<cl_float>(11, (float) chainFrequency);
        cl.executeKernel(kernel, 1, 1);
    };

    // N.B. We ignore the incoming kineticEnergy and grab it from the device buffer instead
    if (nAtoms) {
        // Chain acting on the thermostated atoms: state key 2*chainID, chain type 0.

        OpenCLArray& state = prepareChainState(2*chainID);
        launchPropagation(state, 0, numDOFs, BOLTZ * temperature, frequency);
    }
    if (nPairs) {
        // Chain acting on the thermostated pairs: state key 2*chainID+1, chain type 1,
        // with three degrees of freedom per pair.

        OpenCLArray& state = prepareChainState(2*chainID+1);
        launchPropagation(state, 1, 3*nPairs, BOLTZ * relativeTemperature, relativeFrequency);
    }
    return {0, 0};
}

/**
 * Compute the heat bath energy of a Nose-Hoover chain.
 *
 * The chain's state arrays are looked up under keys 2*chainID and 2*chainID+1
 * in the integration utilities' chain state container.  An entry is used only
 * if it exists, is initialized, and holds exactly chainLength elements.  For
 * each usable entry the computeHeatBathEnergy kernel is run against the
 * single-element heatBathEnergy array, which is cleared once beforehand.
 *
 * @param context  the owning context (not read by this implementation)
 * @param nhc      the chain whose heat bath energy is requested
 * @return the value downloaded from heatBathEnergy, or 0.0 when neither
 *         state entry is usable
 */
double OpenCLNoseHooverChainKernel::computeHeatBathEnergy(ContextImpl& context, const NoseHooverChain &nhc) {
    const bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    const int chainID = nhc.getChainID();
    const int chainLength = nhc.getChainLength();
    const int absIndex = 2*chainID;
    const int relIndex = 2*chainID+1;

    auto& chainState = cl.getIntegrationUtilities().getNoseHooverChainState();

    // count() is tested first so that operator[] never inserts a new entry.
    auto isUsable = [&](int index) {
        return chainState.count(index) != 0 &&
               chainState[index].isInitialized() &&
               chainState[index].getSize() == chainLength;
    };
    const bool absChainIsValid = isUsable(absIndex);
    const bool relChainIsValid = isUsable(relIndex);

    if (!absChainIsValid && !relChainIsValid)
        return 0.0;

    // Lazily create the one-element result array, matching the precision in use.
    if (!heatBathEnergy.isInitialized() || heatBathEnergy.getSize() == 0) {
        if (useDouble) {
            std::vector<cl_double> one(1);
            heatBathEnergy.initialize<cl_double>(cl, 1, "heatBathEnergy");
            heatBathEnergy.upload(one);
        }
        else {
            std::vector<cl_float> one(1);
            heatBathEnergy.initialize<cl_float>(cl, 1, "heatBathEnergy");
            heatBathEnergy.upload(one);
        }
    }

    cl.clearBuffer(heatBathEnergy);

    // Bind the kernel arguments for one state array and launch the kernel.
    // NOTE(review): the buffer is cleared only once above, so the kernel
    // presumably adds into heatBathEnergy -- confirm against the kernel source.
    auto runKernel = [&](int index, int numDOFs, double temperature, double frequency) {
        const double kT = BOLTZ * temperature;
        computeHeatBathEnergyKernel.setArg<cl::Buffer>(0, heatBathEnergy.getDeviceBuffer());
        computeHeatBathEnergyKernel.setArg<cl_int>(1, chainLength);
        computeHeatBathEnergyKernel.setArg<cl_int>(2, numDOFs);
        if (useDouble)
            computeHeatBathEnergyKernel.setArg<cl_double>(3, kT);
        else
            computeHeatBathEnergyKernel.setArg<cl_float>(3, (float) kT);
        computeHeatBathEnergyKernel.setArg<cl_float>(4, (float) frequency);
        computeHeatBathEnergyKernel.setArg<cl::Buffer>(5, chainState[index].getDeviceBuffer());
        cl.executeKernel(computeHeatBathEnergyKernel, 1, 1);
    };

    if (absChainIsValid)
        runKernel(absIndex, nhc.getNumDegreesOfFreedom(), nhc.getTemperature(), nhc.getCollisionFrequency());
    if (relChainIsValid) {
        // Three degrees of freedom per thermostated pair.
        const int numDOFs = 3 * nhc.getThermostatedPairs().size();
        runKernel(relIndex, numDOFs, nhc.getRelativeTemperature(), nhc.getRelativeCollisionFrequency());
    }

    // Read the result back, interpreting it with the precision it was stored in.
    void* pinnedBuffer = cl.getPinnedBuffer();
    heatBathEnergy.download(pinnedBuffer);
    if (useDouble)
        return *((double*) pinnedBuffer);
    return *((float*) pinnedBuffer);
}

/**
 * Compute the kinetic energy of the atoms and pairs thermostated by a chain.
 *
 * On first use for a given chain ID the atom list and pair list are uploaded
 * to the device and cached in atomlists / pairlists.  On every call the
 * per-atom and per-pair kernels are run and reduceEnergyKernel reduces
 * energyBuffer into the one-element kineticEnergyBuffer.
 *
 * @param context        the owning context (not read by this implementation)
 * @param nhc            the chain whose thermostated atoms and pairs are used
 * @param downloadValue  if true, the reduced value is copied back to the host
 * @return the x and y components of kineticEnergyBuffer when downloadValue is
 *         true, otherwise {0, 0}
 * @throws OpenMMException if the number of atoms or pairs differs from the
 *         size of the list cached for this chain ID
 */
std::pair<double, double> OpenCLNoseHooverChainKernel::computeMaskedKineticEnergy(ContextImpl& context, const NoseHooverChain &nhc, bool downloadValue) {

    bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    int chainID = nhc.getChainID();
    const auto & nhcAtoms = nhc.getThermostatedAtoms();
    const auto & nhcPairs = nhc.getThermostatedPairs();
    auto nAtoms = nhcAtoms.size();
    auto nPairs = nhcPairs.size();
    if (nAtoms) {
        if (!atomlists.count(chainID)) {
            // First call for this chain: upload the atom indices.
            atomlists[chainID] = OpenCLArray();
            atomlists[chainID].initialize<int>(cl, nAtoms, "atomlist" + std::to_string(chainID));
            atomlists[chainID].upload(nhcAtoms);
        }
        if (atomlists[chainID].getSize() != nAtoms) {
            throw OpenMMException("Number of atoms changed. Cannot be handled by the same Nose-Hoover thermostat.");
        }
    }
    if (nPairs) {
        if (!pairlists.count(chainID)) {
            // First call for this chain: repack the pairs as mm_int2 and upload them.
            pairlists[chainID] = OpenCLArray();
            pairlists[chainID].initialize<mm_int2>(cl, nPairs, "pairlist" + std::to_string(chainID));
            std::vector<mm_int2> int2vec;
            int2vec.reserve(nPairs);
            for (const auto &p : nhcPairs)
                int2vec.push_back(mm_int2(p.first, p.second));
            pairlists[chainID].upload(int2vec);
        }
        if (pairlists[chainID].getSize() != nPairs) {
            throw OpenMMException("Number of thermostated pairs changed. Cannot be handled by the same Nose-Hoover thermostat.");
        }
    }

    // Lazily create the one-element result array, matching the precision in use.
    if (!kineticEnergyBuffer.isInitialized() || kineticEnergyBuffer.getSize() == 0) {
        if (useDouble) {
            std::vector<mm_double2> zeros{{0,0}};
            kineticEnergyBuffer.initialize<mm_double2>(cl, 1, "kineticEnergyBuffer");
            kineticEnergyBuffer.upload(zeros);
        } else {
            std::vector<mm_float2> zeros{{0,0}};
            kineticEnergyBuffer.initialize<mm_float2>(cl, 1, "kineticEnergyBuffer");
            kineticEnergyBuffer.upload(zeros);
        }
    }

    // NOTE(review): this clears cl.getEnergyBuffer(), while the kernels below
    // are bound to energyBuffer -- confirm both names refer to the same array.
    cl.clearBuffer(cl.getEnergyBuffer());
    if (nAtoms) {
        computeAtomsKineticEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
        computeAtomsKineticEnergyKernel.setArg<cl_int>(1, nAtoms);
        computeAtomsKineticEnergyKernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        computeAtomsKineticEnergyKernel.setArg<cl::Buffer>(3, atomlists[chainID].getDeviceBuffer());
        cl.executeKernel(computeAtomsKineticEnergyKernel, nAtoms);
    }
    if (nPairs) {
        computePairsKineticEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
        computePairsKineticEnergyKernel.setArg<cl_int>(1, nPairs);
        computePairsKineticEnergyKernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        computePairsKineticEnergyKernel.setArg<cl::Buffer>(3, pairlists[chainID].getDeviceBuffer());
        cl.executeKernel(computePairsKineticEnergyKernel, nPairs);
    }

    // Reduce energyBuffer into kineticEnergyBuffer with a single work group
    // whose size is capped at 512.
    int bufferSize = energyBuffer.getSize();
    int workGroupSize  = cl.getDevice().getInfo<CL_DEVICE_MAX_WORK_GROUP_SIZE>();
    if (workGroupSize > 512)
        workGroupSize = 512;
    reduceEnergyKernel.setArg<cl::Buffer>(0, energyBuffer.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl::Buffer>(1, kineticEnergyBuffer.getDeviceBuffer());
    reduceEnergyKernel.setArg<cl_int>(2, bufferSize);
    reduceEnergyKernel.setArg<cl_int>(3, workGroupSize);
    // Size-only argument with a NULL pointer: one energyBuffer element per work item.
    reduceEnergyKernel.setArg(4, workGroupSize*energyBuffer.getElementSize(), NULL);
    cl.executeKernel(reduceEnergyKernel, workGroupSize, workGroupSize);

    std::pair<double, double> KEs = {0, 0};
    if (downloadValue) {
        if (useDouble) {
            mm_double2 tmp;
            kineticEnergyBuffer.download(&tmp);
            KEs.first = tmp.x;
            KEs.second = tmp.y;
        } else {
            mm_float2 tmp;
            kineticEnergyBuffer.download(&tmp);
            KEs.first = tmp.x;
            KEs.second = tmp.y;
        }
    }
    return KEs;
}

/**
 * Scale the velocities of the atoms and pairs thermostated by a chain.
 *
 * The scale factors are read on the device from scaleFactorBuffer; the
 * scaleFactor argument is not read by this implementation.  The atom and pair
 * index lists are the ones cached by computeMaskedKineticEnergy, so that
 * method must have been called for this chain first.
 *
 * @param context      the owning context (not read by this implementation)
 * @param nhc          the chain whose thermostated atoms and pairs are scaled
 * @param scaleFactor  unused here; see above
 * @throws OpenMMException if the atom or pair list for this chain has not
 *         been uploaded yet
 */
void OpenCLNoseHooverChainKernel::scaleVelocities(ContextImpl& context, const NoseHooverChain &nhc, std::pair<double, double> scaleFactor) {
    int chainID = nhc.getChainID();
    auto nAtoms = nhc.getThermostatedAtoms().size();
    auto nPairs = nhc.getThermostatedPairs().size();
    if (nAtoms) {
        // operator[] would silently insert a default-constructed array here,
        // which would also defeat the count() check in computeMaskedKineticEnergy.
        if (!atomlists.count(chainID))
            throw OpenMMException("Nose-Hoover atom list has not been uploaded; computeMaskedKineticEnergy must be called before scaleVelocities.");
        scaleAtomsVelocitiesKernel.setArg<cl::Buffer>(0, scaleFactorBuffer.getDeviceBuffer());
        scaleAtomsVelocitiesKernel.setArg<cl_int>(1, nAtoms);
        scaleAtomsVelocitiesKernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        scaleAtomsVelocitiesKernel.setArg<cl::Buffer>(3, atomlists[chainID].getDeviceBuffer());
        cl.executeKernel(scaleAtomsVelocitiesKernel, nAtoms);
    }
    if (nPairs) {
        if (!pairlists.count(chainID))
            throw OpenMMException("Nose-Hoover pair list has not been uploaded; computeMaskedKineticEnergy must be called before scaleVelocities.");
        scalePairsVelocitiesKernel.setArg<cl::Buffer>(0, scaleFactorBuffer.getDeviceBuffer());
        scalePairsVelocitiesKernel.setArg<cl_int>(1, nPairs);
        scalePairsVelocitiesKernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        scalePairsVelocitiesKernel.setArg<cl::Buffer>(3, pairlists[chainID].getDeviceBuffer());
        cl.executeKernel(scalePairsVelocitiesKernel, nPairs);
    }
}
void OpenCLApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) {
peastman's avatar
peastman committed
9219
9220
    savedPositions.initialize(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedPositions");
    savedForces.initialize(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedForces");
9221
    cl::Program program = cl.createProgram(OpenCLKernelSources::monteCarloBarostat);
9222
    kernel = cl::Kernel(program, "scalePositions");
9223
9224
}

/**
 * Save the current positions and forces, then run the scalePositions kernel.
 *
 * On the first call the molecule definitions are flattened into two device
 * arrays: moleculeAtoms lists the atom indices molecule by molecule, and
 * moleculeStartIndex[i] is the offset of molecule i in that list, with one
 * extra trailing entry holding the total count.  The kernel arguments that
 * do not change between calls are bound at the same time.
 *
 * @param context  supplies the molecule definitions on the first call
 * @param scaleX   scale factor passed to the kernel as argument 0
 * @param scaleY   scale factor passed to the kernel as argument 1
 * @param scaleZ   scale factor passed to the kernel as argument 2
 */
void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, double scaleX, double scaleY, double scaleZ) {
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // Create the arrays with the molecule definitions.

        const auto& molecules = context.getMolecules();
        numMolecules = molecules.size();
        moleculeAtoms.initialize<int>(cl, cl.getNumAtoms(), "moleculeAtoms");
        moleculeStartIndex.initialize<int>(cl, numMolecules+1, "moleculeStartIndex");
        vector<int> atoms(moleculeAtoms.getSize());
        vector<int> startIndex(moleculeStartIndex.getSize());
        int index = 0;
        for (int i = 0; i < numMolecules; i++) {
            startIndex[i] = index;
            for (int atom : molecules[i])
                atoms[index++] = atom;
        }
        startIndex[numMolecules] = index;
        moleculeAtoms.upload(atoms);
        moleculeStartIndex.upload(startIndex);

        // Initialize the kernel arguments.

        kernel.setArg<cl_int>(3, numMolecules);
        kernel.setArg<cl::Buffer>(9, cl.getPosq().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(10, moleculeAtoms.getDeviceBuffer());
        kernel.setArg<cl::Buffer>(11, moleculeStartIndex.getDeviceBuffer());
    }

    // Keep a copy of the positions and forces so restoreCoordinates() can put them back.

    const size_t bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
    cl.getQueue().enqueueCopyBuffer(cl.getPosq().getDeviceBuffer(), savedPositions.getDeviceBuffer(), 0, 0, bytesToCopy);
    cl.getQueue().enqueueCopyBuffer(cl.getForce().getDeviceBuffer(), savedForces.getDeviceBuffer(), 0, 0, bytesToCopy);

    // Bind the per-call arguments and run the kernel.
    // NOTE(review): setPeriodicBoxArgs presumably fills arguments 4 through 8 -- confirm.

    kernel.setArg<cl_float>(0, (cl_float) scaleX);
    kernel.setArg<cl_float>(1, (cl_float) scaleY);
    kernel.setArg<cl_float>(2, (cl_float) scaleZ);
    setPeriodicBoxArgs(cl, kernel, 4);
    cl.executeKernel(kernel, cl.getNumAtoms());

    // Reset every cell offset to zero and record the atom order in effect.

    for (auto& offset : cl.getPosCellOffsets())
        offset = mm_int4(0, 0, 0, 0);
    lastAtomOrder = cl.getAtomIndex();
}

/**
 * Copy savedPositions and savedForces back into the context's position and
 * force buffers, undoing the copy made by scaleCoordinates().
 *
 * @param context  not read by this implementation
 */
void OpenCLApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) {
    const size_t bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
    cl.getQueue().enqueueCopyBuffer(savedPositions.getDeviceBuffer(), cl.getPosq().getDeviceBuffer(), 0, 0, bytesToCopy);
    cl.getQueue().enqueueCopyBuffer(savedForces.getDeviceBuffer(), cl.getForce().getDeviceBuffer(), 0, 0, bytesToCopy);
}

void OpenCLRemoveCMMotionKernel::initialize(const System& system, const CMMotionRemover& force) {
    frequency = force.getFrequency();
    int numAtoms = cl.getNumAtoms();
peastman's avatar
peastman committed
9276
    cmMomentum.initialize<mm_float4>(cl, (numAtoms+OpenCLContext::ThreadBlockSize-1)/OpenCLContext::ThreadBlockSize, "cmMomentum");
9277
9278
9279
9280
    double totalMass = 0.0;
    for (int i = 0; i < numAtoms; i++)
        totalMass += system.getParticleMass(i);
    map<string, string> defines;
9281
    defines["INVERSE_TOTAL_MASS"] = cl.doubleToString(totalMass == 0 ? 0.0 : 1.0/totalMass);
9282
    cl::Program program = cl.createProgram(OpenCLKernelSources::removeCM, defines);
9283
9284
9285
    kernel1 = cl::Kernel(program, "calcCenterOfMassMomentum");
    kernel1.setArg<cl_int>(0, numAtoms);
    kernel1.setArg<cl::Buffer>(1, cl.getVelm().getDeviceBuffer());
peastman's avatar
peastman committed
9286
    kernel1.setArg<cl::Buffer>(2, cmMomentum.getDeviceBuffer());
9287
9288
9289
9290
    kernel1.setArg(3, OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
    kernel2 = cl::Kernel(program, "removeCenterOfMassMomentum");
    kernel2.setArg<cl_int>(0, numAtoms);
    kernel2.setArg<cl::Buffer>(1, cl.getVelm().getDeviceBuffer());
peastman's avatar
peastman committed
9291
    kernel2.setArg<cl::Buffer>(2, cmMomentum.getDeviceBuffer());
9292
9293
9294
9295
9296
9297
9298
    kernel2.setArg(3, OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
}

/**
 * Run the two kernels created by initialize(), in order:
 * calcCenterOfMassMomentum followed by removeCenterOfMassMomentum.
 *
 * @param context  not read by this implementation
 */
void OpenCLRemoveCMMotionKernel::execute(ContextImpl& context) {
    const auto numAtoms = cl.getNumAtoms();
    cl.executeKernel(kernel1, numAtoms);
    cl.executeKernel(kernel2, numAtoms);
}