OpenCLKernels.cpp 343 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2008-2015 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLKernels.h"
28
#include "OpenCLForceInfo.h"
29
30
#include "openmm/LangevinIntegrator.h"
#include "openmm/Context.h"
31
#include "openmm/internal/AndersenThermostatImpl.h"
32
#include "openmm/internal/CMAPTorsionForceImpl.h"
33
#include "openmm/internal/ContextImpl.h"
34
#include "openmm/internal/CustomCompoundBondForceImpl.h"
35
#include "openmm/internal/CustomHbondForceImpl.h"
36
#include "openmm/internal/CustomManyParticleForceImpl.h"
37
#include "openmm/internal/CustomNonbondedForceImpl.h"
38
#include "openmm/internal/NonbondedForceImpl.h"
39
#include "openmm/internal/OSRngSeed.h"
Peter Eastman's avatar
Peter Eastman committed
40
#include "OpenCLBondedUtilities.h"
41
#include "OpenCLExpressionUtilities.h"
42
#include "OpenCLIntegrationUtilities.h"
43
#include "OpenCLNonbondedUtilities.h"
44
#include "OpenCLKernelSources.h"
45
#include "lepton/ExpressionTreeNode.h"
46
#include "lepton/Operation.h"
47
48
#include "lepton/Parser.h"
#include "lepton/ParsedExpression.h"
49
50
#include "SimTKOpenMMRealType.h"
#include "SimTKOpenMMUtilities.h"
51
#include <algorithm>
52
#include <cmath>
53
#include <set>
54
55
56

using namespace OpenMM;
using namespace std;
57
58
using Lepton::ExpressionTreeNode;
using Lepton::Operation;
59

60
61
62
63
64
65
66
/**
 * Set the kernel argument that carries the position correction buffer.
 *
 * When the context reports mixed precision, the device buffer of
 * cl.getPosqCorrection() is bound at the given index.  Otherwise a NULL
 * pointer is passed for that argument.
 *
 * @param cl      the context to query
 * @param kernel  the kernel whose argument is set
 * @param index   the index of the argument to set
 */
static void setPosqCorrectionArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseMixedPrecision()) {
        // Not in mixed precision mode: pass a null pointer for this argument.
        kernel.setArg<void*>(index, NULL);
        return;
    }
    kernel.setArg<cl::Buffer>(index, cl.getPosqCorrection().getDeviceBuffer());
}

67
68
69
70
71
72
73
/**
 * Set the kernel argument holding the periodic box size.
 *
 * The value is passed as an mm_double4 when the context uses double precision
 * and as an mm_float4 otherwise.
 *
 * @param cl      the context to query
 * @param kernel  the kernel whose argument is set
 * @param index   the index of the argument to set
 */
static void setPeriodicBoxSizeArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
}

74
/**
 * Set five consecutive kernel arguments describing the periodic box, starting
 * at the given index.  In order they are: the box size, the inverse box size,
 * and the X, Y and Z box vectors.
 *
 * The values are passed as mm_double4 when the context uses double precision
 * and as mm_float4 otherwise.
 *
 * @param cl      the context to query
 * @param kernel  the kernel whose arguments are set
 * @param index   the index of the first of the five arguments
 */
static void setPeriodicBoxArgs(OpenCLContext& cl, cl::Kernel& kernel, int index) {
    if (!cl.getUseDoublePrecision()) {
        kernel.setArg<mm_float4>(index, cl.getPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+1, cl.getInvPeriodicBoxSize());
        kernel.setArg<mm_float4>(index+2, cl.getPeriodicBoxVecX());
        kernel.setArg<mm_float4>(index+3, cl.getPeriodicBoxVecY());
        kernel.setArg<mm_float4>(index+4, cl.getPeriodicBoxVecZ());
        return;
    }
    kernel.setArg<mm_double4>(index, cl.getPeriodicBoxSizeDouble());
    kernel.setArg<mm_double4>(index+1, cl.getInvPeriodicBoxSizeDouble());
    kernel.setArg<mm_double4>(index+2, cl.getPeriodicBoxVecXDouble());
    kernel.setArg<mm_double4>(index+3, cl.getPeriodicBoxVecYDouble());
    kernel.setArg<mm_double4>(index+4, cl.getPeriodicBoxVecZDouble());
}

91
92
93
94
95
96
97
/**
 * Report whether an expression is the constant zero.
 *
 * Only the root node is inspected: the result is true when the root operation
 * is a CONSTANT whose value compares equal to 0.0, and false in every other case.
 */
static bool isZeroExpression(const Lepton::ParsedExpression& expression) {
    const Lepton::Operation& rootOperation = expression.getRootNode().getOperation();
    if (rootOperation.getId() == Lepton::Operation::CONSTANT) {
        const Lepton::Operation::Constant& constant = dynamic_cast<const Lepton::Operation::Constant&>(rootOperation);
        return (constant.getValue() == 0.0);
    }
    return false;
}

98
99
100
101
102
103
104
105
106
107
108
109
110
111
/**
 * Report whether the expression subtree rooted at a node refers to a variable.
 *
 * The node matches when its operation is a VARIABLE whose name equals the
 * requested one; otherwise the children are searched recursively.
 *
 * @param node      the root of the subtree to search
 * @param variable  the name of the variable to look for
 * @return true if any node in the subtree is the named variable
 */
static bool usesVariable(const Lepton::ExpressionTreeNode& node, const string& variable) {
    const Lepton::Operation& operation = node.getOperation();
    const bool isRequestedVariable = (operation.getId() == Lepton::Operation::VARIABLE && operation.getName() == variable);
    if (isRequestedVariable)
        return true;
    const vector<Lepton::ExpressionTreeNode>& children = node.getChildren();
    const int numChildren = (int) children.size();
    for (int child = 0; child < numChildren; child++) {
        if (usesVariable(children[child], variable))
            return true;
    }
    return false;
}

/**
 * Report whether a parsed expression refers to a variable, by searching the
 * tree that starts at the expression's root node.
 *
 * @param expression  the expression to search
 * @param variable    the name of the variable to look for
 */
static bool usesVariable(const Lepton::ParsedExpression& expression, const string& variable) {
    const Lepton::ExpressionTreeNode& root = expression.getRootNode();
    return usesVariable(root, variable);
}

112
113
114
115
/**
 * Build a (node, value) pair whose node is a Variable operation with the given
 * name and whose second element is the given string.
 *
 * @param name   the name given to the Variable operation
 * @param value  the string stored alongside the node
 */
static pair<ExpressionTreeNode, string> makeVariable(const string& name, const string& value) {
    ExpressionTreeNode variableNode(new Operation::Variable(name));
    return pair<ExpressionTreeNode, string>(variableNode, value);
}

116
/**
 * Initialize the kernel.  This implementation does nothing.
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) {
    // Intentionally empty.
}

119
/**
 * Start a force/energy computation.
 *
 * In order, this clears the context's autoclear buffers, invokes every
 * registered pre-computation, increments the context's compute-force counter,
 * and prepares the nonbonded interactions when the nonbonded force group is
 * part of the requested groups.
 *
 * @param context        the context in which the computation runs (unused here)
 * @param includeForces  forwarded to each pre-computation
 * @param includeEnergy  forwarded to each pre-computation
 * @param groups         bit flags selecting the force groups to include
 */
void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
    cl.clearAutoclearBuffers();
    vector<OpenCLContext::ForcePreComputation*>& preComputations = cl.getPreComputations();
    for (size_t i = 0; i < preComputations.size(); i++)
        preComputations[i]->computeForceAndEnergy(includeForces, includeEnergy, groups);
    OpenCLNonbondedUtilities& nonbonded = cl.getNonbondedUtilities();
    const int nonbondedGroupFlag = 1<<nonbonded.getForceGroup();
    const bool nonbondedRequested = ((groups&nonbondedGroupFlag) != 0);
    cl.setComputeForceCount(cl.getComputeForceCount()+1);
    if (nonbondedRequested)
        nonbonded.prepareInteractions();
}

130
/**
 * Finish a force/energy computation and return the accumulated energy.
 *
 * In order, this computes the bonded interactions, computes the nonbonded
 * interactions if their force group is requested, reduces the forces, adds up
 * the values returned by every registered post-computation, and distributes
 * forces from virtual sites.  When includeEnergy is true, every element of the
 * context's energy buffer is downloaded and added to the returned sum.
 *
 * @param context        the context in which the computation runs (unused here)
 * @param includeForces  forwarded to each post-computation
 * @param includeEnergy  whether the energy buffer is summed into the result
 * @param groups         bit flags selecting the force groups to include
 * @param valid          not modified by this function
 * @return the sum of the post-computation results plus, if requested, the
 *         contents of the energy buffer
 */
double OpenCLCalcForcesAndEnergyKernel::finishComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups, bool& valid) {
    cl.getBondedUtilities().computeInteractions(groups);
    const int nonbondedGroupFlag = 1<<cl.getNonbondedUtilities().getForceGroup();
    if ((groups&nonbondedGroupFlag) != 0)
        cl.getNonbondedUtilities().computeInteractions();
    cl.reduceForces();

    double totalEnergy = 0.0;
    vector<OpenCLContext::ForcePostComputation*>& postComputations = cl.getPostComputations();
    for (size_t i = 0; i < postComputations.size(); i++)
        totalEnergy += postComputations[i]->computeForceAndEnergy(includeForces, includeEnergy, groups);
    cl.getIntegrationUtilities().distributeForcesFromVirtualSites();
    if (!includeEnergy)
        return totalEnergy;

    // Download the energy buffer in the element type matching the precision mode and sum it.
    OpenCLArray& energyBuffer = cl.getEnergyBuffer();
    if (cl.getUseDoublePrecision()) {
        double* hostEnergy = (double*) cl.getPinnedBuffer();
        energyBuffer.download(hostEnergy);
        for (int element = 0; element < energyBuffer.getSize(); element++)
            totalEnergy += hostEnergy[element];
    }
    else {
        float* hostEnergy = (float*) cl.getPinnedBuffer();
        energyBuffer.download(hostEnergy);
        for (int element = 0; element < energyBuffer.getSize(); element++)
            totalEnergy += hostEnergy[element];
    }
    return totalEnergy;
}

157
/**
 * Initialize the kernel.  This implementation does nothing.
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLUpdateStateDataKernel::initialize(const System& system) {
    // Intentionally empty.
}

160
double OpenCLUpdateStateDataKernel::getTime(const ContextImpl& context) const {
161
    return cl.getTime();
162
163
}

164
/**
 * Set the current time on every OpenCLContext listed in the platform data.
 *
 * @param context  the context being updated (unused here)
 * @param time     the time to set
 */
void OpenCLUpdateStateDataKernel::setTime(ContextImpl& context, double time) {
    vector<OpenCLContext*>& allContexts = cl.getPlatformData().contexts;
    for (vector<OpenCLContext*>::iterator ctx = allContexts.begin(); ctx != allContexts.end(); ++ctx)
        (*ctx)->setTime(time);
}

Peter Eastman's avatar
Peter Eastman committed
170
/**
 * Copy the particle positions from the device into a vector.
 *
 * The posq array is downloaded through the context's pinned buffer.  For each
 * device slot i, the position is stored at positions[order[i]] (order being
 * cl.getAtomIndex()), after subtracting the periodic box vectors scaled by the
 * cell offset recorded for that slot.  In mixed precision mode the float
 * position and its float correction are added together in double precision.
 *
 * @param context    used to obtain the number of particles
 * @param positions  resized to the number of particles and filled in
 */
void OpenCLUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) {
    const vector<cl_int>& atomOrder = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    positions.resize(numParticles);
    Vec3 boxA, boxB, boxC;
    cl.getPeriodicBoxVectors(boxA, boxB, boxC);

    if (cl.getUseDoublePrecision()) {
        // Positions are stored on the device as double4.
        mm_double4* hostPosq = (mm_double4*) cl.getPinnedBuffer();
        cl.getPosq().download(hostPosq);
        for (int slot = 0; slot < numParticles; ++slot) {
            const mm_double4 p = hostPosq[slot];
            const mm_int4 cell = cl.getPosCellOffsets()[slot];
            positions[atomOrder[slot]] = Vec3(p.x, p.y, p.z)-boxA*cell.x-boxB*cell.y-boxC*cell.z;
        }
        return;
    }

    if (cl.getUseMixedPrecision()) {
        // Positions are float4 plus a separate float4 correction; combine them in double.
        mm_float4* hostPosq = (mm_float4*) cl.getPinnedBuffer();
        vector<mm_float4> correction;
        cl.getPosq().download(hostPosq);
        cl.getPosqCorrection().download(correction);
        for (int slot = 0; slot < numParticles; ++slot) {
            const mm_float4 base = hostPosq[slot];
            const mm_float4 extra = correction[slot];
            const mm_int4 cell = cl.getPosCellOffsets()[slot];
            positions[atomOrder[slot]] = Vec3((double)base.x+(double)extra.x, (double)base.y+(double)extra.y, (double)base.z+(double)extra.z)-boxA*cell.x-boxB*cell.y-boxC*cell.z;
        }
        return;
    }

    // Single precision: positions are plain float4.
    mm_float4* hostPosq = (mm_float4*) cl.getPinnedBuffer();
    cl.getPosq().download(hostPosq);
    for (int slot = 0; slot < numParticles; ++slot) {
        const mm_float4 p = hostPosq[slot];
        const mm_int4 cell = cl.getPosCellOffsets()[slot];
        positions[atomOrder[slot]] = Vec3(p.x, p.y, p.z)-boxA*cell.x-boxB*cell.y-boxC*cell.z;
    }
}

Peter Eastman's avatar
Peter Eastman committed
208
/**
 * Upload new particle positions to the device.
 *
 * The existing posq array is first downloaded so that the w component of each
 * entry is kept; only x, y and z are overwritten, with device slot i taking
 * positions[order[i]] (order being cl.getAtomIndex()).  Entries from the number
 * of particles up to cl.getPaddedNumAtoms() are set to zero.  In mixed
 * precision mode the correction array is also filled with the difference
 * between each double coordinate and its float rounding.  Finally all cell
 * offsets are reset to zero and cl.reorderAtoms() is called.
 *
 * @param context    used to obtain the number of particles
 * @param positions  the new positions, indexed by original particle index
 */
void OpenCLUpdateStateDataKernel::setPositions(ContextImpl& context, const vector<Vec3>& positions) {
    const vector<cl_int>& atomOrder = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    const int paddedNumAtoms = cl.getPaddedNumAtoms();

    if (cl.getUseDoublePrecision()) {
        mm_double4* hostPosq = (mm_double4*) cl.getPinnedBuffer();
        cl.getPosq().download(hostPosq);
        for (int slot = 0; slot < numParticles; ++slot) {
            const Vec3& source = positions[atomOrder[slot]];
            hostPosq[slot].x = source[0];
            hostPosq[slot].y = source[1];
            hostPosq[slot].z = source[2];
        }
        for (int slot = numParticles; slot < paddedNumAtoms; slot++)
            hostPosq[slot] = mm_double4(0.0, 0.0, 0.0, 0.0);
        cl.getPosq().upload(hostPosq);
    }
    else {
        mm_float4* hostPosq = (mm_float4*) cl.getPinnedBuffer();
        cl.getPosq().download(hostPosq);
        for (int slot = 0; slot < numParticles; ++slot) {
            const Vec3& source = positions[atomOrder[slot]];
            hostPosq[slot].x = (cl_float) source[0];
            hostPosq[slot].y = (cl_float) source[1];
            hostPosq[slot].z = (cl_float) source[2];
        }
        for (int slot = numParticles; slot < paddedNumAtoms; slot++)
            hostPosq[slot] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
        cl.getPosq().upload(hostPosq);
    }

    if (cl.getUseMixedPrecision()) {
        // The pinned buffer is reused here; posq has already been uploaded above.
        mm_float4* hostCorrection = (mm_float4*) cl.getPinnedBuffer();
        for (int slot = 0; slot < numParticles; ++slot) {
            const Vec3& source = positions[atomOrder[slot]];
            mm_float4& correction = hostCorrection[slot];
            // Store what is lost when each coordinate is rounded to float.
            correction.x = (cl_float) (source[0]-(cl_float)source[0]);
            correction.y = (cl_float) (source[1]-(cl_float)source[1]);
            correction.z = (cl_float) (source[2]-(cl_float)source[2]);
            correction.w = 0;
        }
        for (int slot = numParticles; slot < paddedNumAtoms; slot++)
            hostCorrection[slot] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
        cl.getPosqCorrection().upload(hostCorrection);
    }

    const int numOffsets = (int) cl.getPosCellOffsets().size();
    for (int slot = 0; slot < numOffsets; slot++)
        cl.getPosCellOffsets()[slot] = mm_int4(0, 0, 0, 0);
    cl.reorderAtoms();
}

Peter Eastman's avatar
Peter Eastman committed
258
/**
 * Copy the particle velocities from the device into a vector.
 *
 * The velm array is downloaded through the context's pinned buffer and read as
 * mm_double4 in double or mixed precision mode, and as mm_float4 otherwise.
 * The x, y and z components of device slot i are stored at
 * velocities[order[i]], where order is cl.getAtomIndex().
 *
 * @param context     used to obtain the number of particles
 * @param velocities  resized to the number of particles and filled in
 */
void OpenCLUpdateStateDataKernel::getVelocities(ContextImpl& context, vector<Vec3>& velocities) {
    const vector<cl_int>& order = cl.getAtomIndex();
    int numParticles = context.getSystem().getNumParticles();
    velocities.resize(numParticles);
    if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
        mm_double4* velm = (mm_double4*) cl.getPinnedBuffer();
        cl.getVelm().download(velm);
        for (int i = 0; i < numParticles; ++i) {
            const mm_double4& vel = velm[i];
            velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
        }
    }
    else {
        mm_float4* velm = (mm_float4*) cl.getPinnedBuffer();
        cl.getVelm().download(velm);
        for (int i = 0; i < numParticles; ++i) {
            const mm_float4& vel = velm[i];
            velocities[order[i]] = Vec3(vel.x, vel.y, vel.z);
        }
    }
}

Peter Eastman's avatar
Peter Eastman committed
282
/**
 * Upload new particle velocities to the device.
 *
 * The existing velm array is first downloaded so that the w component of each
 * entry is kept; only x, y and z are overwritten, with device slot i taking
 * velocities[order[i]] (order being cl.getAtomIndex()).  Entries from the
 * number of particles up to cl.getPaddedNumAtoms() are set to zero.  The array
 * is treated as mm_double4 in double or mixed precision mode and as mm_float4
 * otherwise.
 *
 * @param context     used to obtain the number of particles
 * @param velocities  the new velocities, indexed by original particle index
 */
void OpenCLUpdateStateDataKernel::setVelocities(ContextImpl& context, const vector<Vec3>& velocities) {
    const vector<cl_int>& atomOrder = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    const int paddedNumAtoms = cl.getPaddedNumAtoms();
    const bool storedAsDouble = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision());

    if (storedAsDouble) {
        mm_double4* hostVelm = (mm_double4*) cl.getPinnedBuffer();
        cl.getVelm().download(hostVelm);
        for (int slot = 0; slot < numParticles; ++slot) {
            const Vec3& source = velocities[atomOrder[slot]];
            hostVelm[slot].x = source[0];
            hostVelm[slot].y = source[1];
            hostVelm[slot].z = source[2];
        }
        for (int slot = numParticles; slot < paddedNumAtoms; slot++)
            hostVelm[slot] = mm_double4(0.0, 0.0, 0.0, 0.0);
        cl.getVelm().upload(hostVelm);
    }
    else {
        mm_float4* hostVelm = (mm_float4*) cl.getPinnedBuffer();
        cl.getVelm().download(hostVelm);
        for (int slot = 0; slot < numParticles; ++slot) {
            const Vec3& source = velocities[atomOrder[slot]];
            hostVelm[slot].x = (cl_float) source[0];
            hostVelm[slot].y = (cl_float) source[1];
            hostVelm[slot].z = (cl_float) source[2];
        }
        for (int slot = numParticles; slot < paddedNumAtoms; slot++)
            hostVelm[slot] = mm_float4(0.0f, 0.0f, 0.0f, 0.0f);
        cl.getVelm().upload(hostVelm);
    }
}

Peter Eastman's avatar
Peter Eastman committed
315
/**
 * Copy the particle forces from the device into a vector.
 *
 * The force array is downloaded through the context's pinned buffer and read as
 * mm_double4 in double precision mode and as mm_float4 otherwise.  The x, y and
 * z components of device slot i are stored at forces[order[i]], where order is
 * cl.getAtomIndex().
 *
 * @param context  used to obtain the number of particles
 * @param forces   resized to the number of particles and filled in
 */
void OpenCLUpdateStateDataKernel::getForces(ContextImpl& context, vector<Vec3>& forces) {
    const vector<cl_int>& atomOrder = cl.getAtomIndex();
    const int numParticles = context.getSystem().getNumParticles();
    forces.resize(numParticles);
    if (!cl.getUseDoublePrecision()) {
        mm_float4* hostForce = (mm_float4*) cl.getPinnedBuffer();
        cl.getForce().download(hostForce);
        for (int slot = 0; slot < numParticles; ++slot) {
            const mm_float4& f = hostForce[slot];
            forces[atomOrder[slot]] = Vec3(f.x, f.y, f.z);
        }
        return;
    }
    mm_double4* hostForce = (mm_double4*) cl.getPinnedBuffer();
    cl.getForce().download(hostForce);
    for (int slot = 0; slot < numParticles; ++slot) {
        const mm_double4& f = hostForce[slot];
        forces[atomOrder[slot]] = Vec3(f.x, f.y, f.z);
    }
}

337
/**
 * Get the periodic box vectors, as reported by this kernel's OpenCLContext.
 *
 * @param context  the context to query (unused here)
 * @param a        receives the first box vector
 * @param b        receives the second box vector
 * @param c        receives the third box vector
 */
void OpenCLUpdateStateDataKernel::getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const {
    // Delegates directly to the OpenCLContext.
    cl.getPeriodicBoxVectors(a, b, c);
}

/**
 * Set the periodic box vectors on every OpenCLContext listed in the platform data.
 *
 * @param context  the context being updated (unused here)
 * @param a        the first box vector
 * @param b        the second box vector
 * @param c        the third box vector
 */
void OpenCLUpdateStateDataKernel::setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c) const {
    vector<OpenCLContext*>& allContexts = cl.getPlatformData().contexts;
    for (vector<OpenCLContext*>::iterator ctx = allContexts.begin(); ctx != allContexts.end(); ++ctx)
        (*ctx)->setPeriodicBoxVectors(a, b, c);
}

Peter Eastman's avatar
Peter Eastman committed
347
/**
 * Write a checkpoint of the simulation state to a stream.
 *
 * The data is written as raw bytes in this order: format version (2),
 * precision code (2 = double, 1 = mixed, 0 = single), time, step count, steps
 * since reorder, the posq array, the posq correction array (mixed precision
 * only), the velm array, the atom index list, the cell offsets, and the three
 * periodic box vectors.  The integration utilities and SimTKOpenMMUtilities
 * then append their own checkpoint data.
 *
 * @param context  the context being checkpointed (unused here)
 * @param stream   the stream to write to
 */
void OpenCLUpdateStateDataKernel::createCheckpoint(ContextImpl& context, ostream& stream) {
    // Fixed-size header fields.
    const int version = 2;
    stream.write((const char*) &version, sizeof(int));
    int precision = 0;
    if (cl.getUseDoublePrecision())
        precision = 2;
    else if (cl.getUseMixedPrecision())
        precision = 1;
    stream.write((const char*) &precision, sizeof(int));
    const double time = cl.getTime();
    stream.write((const char*) &time, sizeof(double));
    const int stepCount = cl.getStepCount();
    stream.write((const char*) &stepCount, sizeof(int));
    const int stepsSinceReorder = cl.getStepsSinceReorder();
    stream.write((const char*) &stepsSinceReorder, sizeof(int));

    // Device arrays, each staged through the pinned buffer.
    char* staging = (char*) cl.getPinnedBuffer();
    cl.getPosq().download(staging);
    stream.write(staging, cl.getPosq().getSize()*cl.getPosq().getElementSize());
    if (cl.getUseMixedPrecision()) {
        cl.getPosqCorrection().download(staging);
        stream.write(staging, cl.getPosqCorrection().getSize()*cl.getPosqCorrection().getElementSize());
    }
    cl.getVelm().download(staging);
    stream.write(staging, cl.getVelm().getSize()*cl.getVelm().getElementSize());

    // Host-side bookkeeping.
    stream.write((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().size());
    stream.write((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
    Vec3 box[3];
    cl.getPeriodicBoxVectors(box[0], box[1], box[2]);
    stream.write((char*) box, 3*sizeof(Vec3));

    // Let the other components append their own state.
    cl.getIntegrationUtilities().createCheckpoint(stream);
    SimTKOpenMMUtilities::createCheckpoint(stream);
}

/**
 * Restore the simulation state from a checkpoint stream.
 *
 * The data is read as raw bytes in this order: format version, precision
 * code, time, step count, steps since reorder, the posq array, the posq
 * correction array (mixed precision only), the velm array, the atom index
 * list, the cell offsets, and the three periodic box vectors.  The
 * integration utilities and SimTKOpenMMUtilities then load their own
 * checkpoint data, and finally every reorder listener is executed.
 *
 * Time, step count, steps since reorder and the box vectors are applied to
 * every OpenCLContext listed in the platform data.
 *
 * @param context  the context being restored (unused here)
 * @param stream   the stream to read from
 * @throws OpenMMException if the version is not 2, if the precision code does
 *         not match this context's precision mode, or if the header could not
 *         be read completely
 */
void OpenCLUpdateStateDataKernel::loadCheckpoint(ContextImpl& context, istream& stream) {
    // Header fields are given known values so that a failed read cannot leave
    // them indeterminate; 0 and -1 are never written by createCheckpoint().
    int version = 0;
    stream.read((char*) &version, sizeof(int));
    if (version != 2)
        throw OpenMMException("Checkpoint was created with a different version of OpenMM");
    int precision = -1;
    stream.read((char*) &precision, sizeof(int));
    int expectedPrecision = (cl.getUseDoublePrecision() ? 2 : cl.getUseMixedPrecision() ? 1 : 0);
    if (precision != expectedPrecision)
        throw OpenMMException("Checkpoint was created with a different numeric precision");
    double time = 0.0;
    stream.read((char*) &time, sizeof(double));
    int stepCount = 0, stepsSinceReorder = 0;
    stream.read((char*) &stepCount, sizeof(int));
    stream.read((char*) &stepsSinceReorder, sizeof(int));

    // Refuse to touch any context state if the header was cut short.
    if (!stream)
        throw OpenMMException("Checkpoint is truncated or could not be read");

    vector<OpenCLContext*>& contexts = cl.getPlatformData().contexts;
    for (int i = 0; i < (int) contexts.size(); i++) {
        contexts[i]->setTime(time);
        contexts[i]->setStepCount(stepCount);
        contexts[i]->setStepsSinceReorder(stepsSinceReorder);
    }

    // Device arrays, each staged through the pinned buffer.
    char* buffer = (char*) cl.getPinnedBuffer();
    stream.read(buffer, cl.getPosq().getSize()*cl.getPosq().getElementSize());
    cl.getPosq().upload(buffer);
    if (cl.getUseMixedPrecision()) {
        stream.read(buffer, cl.getPosqCorrection().getSize()*cl.getPosqCorrection().getElementSize());
        cl.getPosqCorrection().upload(buffer);
    }
    stream.read(buffer, cl.getVelm().getSize()*cl.getVelm().getElementSize());
    cl.getVelm().upload(buffer);

    // Host-side bookkeeping; the atom index list is also pushed to its device array.
    stream.read((char*) &cl.getAtomIndex()[0], sizeof(cl_int)*cl.getAtomIndex().size());
    cl.getAtomIndexArray().upload(cl.getAtomIndex());
    stream.read((char*) &cl.getPosCellOffsets()[0], sizeof(mm_int4)*cl.getPosCellOffsets().size());
    Vec3 boxVectors[3];
    stream.read((char*) &boxVectors, 3*sizeof(Vec3));
    for (int i = 0; i < (int) contexts.size(); i++)
        contexts[i]->setPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);

    // Let the other components restore their own state.
    cl.getIntegrationUtilities().loadCheckpoint(stream);
    SimTKOpenMMUtilities::loadCheckpoint(stream);
    for (int i = 0; i < (int) cl.getReorderListeners().size(); i++)
        cl.getReorderListeners()[i]->execute();
}

419
420
421
422
/**
 * Initialize the kernel.  This implementation does nothing; the OpenCL kernel
 * used by apply() is created on the first call to apply().
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLApplyConstraintsKernel::initialize(const System& system) {
    // Intentionally empty.
}

/**
 * Apply constraints to the particle positions.
 *
 * On the first call the "applyPositionDeltas" kernel is built from the
 * constraints source with NUM_ATOMS defined, and its three arguments are bound
 * to posq, the posq correction (or NULL) and the position delta buffer.  Each
 * call then clears the position delta buffer, asks the integration utilities
 * to apply constraints, runs the kernel over all atoms, and recomputes the
 * virtual sites.
 *
 * @param context  the context in which to apply constraints (unused here)
 * @param tol      the tolerance passed to applyConstraints()
 */
void OpenCLApplyConstraintsKernel::apply(ContextImpl& context, double tol) {
    if (!hasInitializedKernel) {
        // One-time creation of the kernel that adds the position deltas.
        hasInitializedKernel = true;
        map<string, string> kernelDefines;
        kernelDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        cl::Program constraintsProgram = cl.createProgram(OpenCLKernelSources::constraints, kernelDefines);
        applyDeltasKernel = cl::Kernel(constraintsProgram, "applyPositionDeltas");
        applyDeltasKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, applyDeltasKernel, 1);
        applyDeltasKernel.setArg<cl::Buffer>(2, cl.getIntegrationUtilities().getPosDelta().getDeviceBuffer());
    }
    OpenCLIntegrationUtilities& integrationUtils = cl.getIntegrationUtilities();
    cl.clearBuffer(integrationUtils.getPosDelta());
    integrationUtils.applyConstraints(tol);
    cl.executeKernel(applyDeltasKernel, cl.getNumAtoms());
    integrationUtils.computeVirtualSites();
}

440
441
442
443
/**
 * Apply constraints to the particle velocities by delegating to the
 * integration utilities.
 *
 * @param context  the context in which to apply constraints (unused here)
 * @param tol      the tolerance passed to applyVelocityConstraints()
 */
void OpenCLApplyConstraintsKernel::applyToVelocities(ContextImpl& context, double tol) {
    OpenCLIntegrationUtilities& integrationUtils = cl.getIntegrationUtilities();
    integrationUtils.applyVelocityConstraints(tol);
}

444
445
446
447
448
449
450
/**
 * Initialize the kernel.  This implementation does nothing.
 *
 * @param system  the System this kernel will be applied to (unused)
 */
void OpenCLVirtualSitesKernel::initialize(const System& system) {
    // Intentionally empty.
}

/**
 * Compute the virtual site positions by delegating to the integration utilities.
 *
 * @param context  the context in which to compute positions (unused here)
 */
void OpenCLVirtualSitesKernel::computePositions(ContextImpl& context) {
    OpenCLIntegrationUtilities& integrationUtils = cl.getIntegrationUtilities();
    integrationUtils.computeVirtualSites();
}

451
class OpenCLHarmonicBondForceInfo : public OpenCLForceInfo {
452
public:
453
    OpenCLHarmonicBondForceInfo(const HarmonicBondForce& force) : OpenCLForceInfo(0), force(force) {
454
455
456
457
    }
    int getNumParticleGroups() {
        return force.getNumBonds();
    }
Peter Eastman's avatar
Peter Eastman committed
458
    void getParticlesInGroup(int index, vector<int>& particles) {
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
        int particle1, particle2;
        double length, k;
        force.getBondParameters(index, particle1, particle2, length, k);
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2;
        double length1, length2, k1, k2;
        force.getBondParameters(group1, particle1, particle2, length1, k1);
        force.getBondParameters(group2, particle1, particle2, length2, k2);
        return (length1 == length2 && k1 == k2);
    }
private:
    const HarmonicBondForce& force;
};

477
478
479
480
481
/**
 * Release the per-bond parameter array, if one was created.
 */
OpenCLCalcHarmonicBondForceKernel::~OpenCLCalcHarmonicBondForceKernel() {
    // Deleting a null pointer has no effect, so no explicit check is needed.
    delete params;
}

482
void OpenCLCalcHarmonicBondForceKernel::initialize(const System& system, const HarmonicBondForce& force) {
483
484
485
486
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    numBonds = endIndex-startIndex;
487
488
    if (numBonds == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
489
    vector<vector<int> > atoms(numBonds, vector<int>(2));
490
    params = OpenCLArray::create<mm_float2>(cl, numBonds, "bondParams");
491
492
493
    vector<mm_float2> paramVector(numBonds);
    for (int i = 0; i < numBonds; i++) {
        double length, k;
Peter Eastman's avatar
Peter Eastman committed
494
        force.getBondParameters(startIndex+i, atoms[i][0], atoms[i][1], length, k);
495
        paramVector[i] = mm_float2((cl_float) length, (cl_float) k);
496
497
    }
    params->upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
498
    map<string, string> replacements;
499
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::harmonicBondForce;
Peter Eastman's avatar
Peter Eastman committed
500
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float2");
501
502
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
    cl.addForce(new OpenCLHarmonicBondForceInfo(force));
503
504
}

505
/**
 * Execute the kernel.  This method performs no work and reports zero energy;
 * in this file the bond interaction is registered with cl.getBondedUtilities()
 * during initialize().
 *
 * @param context        the context in which to execute (unused)
 * @param includeForces  unused
 * @param includeEnergy  unused
 * @return always 0.0
 */
double OpenCLCalcHarmonicBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const double energy = 0.0;
    return energy;
}
508

509
510
511
512
513
514
/**
 * Copy changed bond parameters from a HarmonicBondForce to the device.
 *
 * The slice of bonds handled by this context is recomputed the same way as in
 * initialize(); its size must match the number recorded there.  The (length,
 * k) pair of each bond in the slice is re-uploaded as float2, and the context
 * is told that its molecules are no longer valid.
 *
 * @param context  the context to copy parameters to (unused here)
 * @param force    the HarmonicBondForce to copy the parameters from
 * @throws OpenMMException if the number of bonds in this context's slice has changed
 */
void OpenCLCalcHarmonicBondForceKernel::copyParametersToContext(ContextImpl& context, const HarmonicBondForce& force) {
    const int numContexts = cl.getPlatformData().contexts.size();
    const int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    const int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    const int sliceSize = endIndex-startIndex;
    if (numBonds != sliceSize)
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;

    // Record the per-bond parameters.

    vector<mm_float2> hostParams(numBonds);
    for (int bond = 0; bond < numBonds; bond++) {
        int ignoredAtom1, ignoredAtom2;
        double length, k;
        force.getBondParameters(startIndex+bond, ignoredAtom1, ignoredAtom2, length, k);
        hostParams[bond] = mm_float2((cl_float) length, (cl_float) k);
    }
    params->upload(hostParams);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

534
535
class OpenCLCustomBondForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
536
    OpenCLCustomBondForceInfo(const CustomBondForce& force) : OpenCLForceInfo(0), force(force) {
537
538
539
540
    }
    int getNumParticleGroups() {
        return force.getNumBonds();
    }
Peter Eastman's avatar
Peter Eastman committed
541
    void getParticlesInGroup(int index, vector<int>& particles) {
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
        int particle1, particle2;
        vector<double> parameters;
        force.getBondParameters(index, particle1, particle2, parameters);
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2;
        vector<double> parameters1, parameters2;
        force.getBondParameters(group1, particle1, particle2, parameters1);
        force.getBondParameters(group2, particle1, particle2, parameters2);
        for (int i = 0; i < (int) parameters1.size(); i++)
            if (parameters1[i] != parameters2[i])
                return false;
        return true;
    }
private:
    const CustomBondForce& force;
};

/**
 * Release the per-bond parameter set and the global parameter array, if they
 * were created.
 */
OpenCLCalcCustomBondForceKernel::~OpenCLCalcCustomBondForceKernel() {
    // Deleting a null pointer has no effect, so no explicit checks are needed.
    delete params;
    delete globals;
}

void OpenCLCalcCustomBondForceKernel::initialize(const System& system, const CustomBondForce& force) {
571
572
573
574
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    numBonds = endIndex-startIndex;
575
576
    if (numBonds == 0)
        return;
577
    vector<vector<int> > atoms(numBonds, vector<int>(2));
578
579
    params = new OpenCLParameterSet(cl, force.getNumPerBondParameters(), numBonds, "customBondParams");
    vector<vector<cl_float> > paramVector(numBonds);
580
581
    for (int i = 0; i < numBonds; i++) {
        vector<double> parameters;
582
        force.getBondParameters(startIndex+i, atoms[i][0], atoms[i][1], parameters);
583
        paramVector[i].resize(parameters.size());
584
        for (int j = 0; j < (int) parameters.size(); j++)
585
            paramVector[i][j] = (cl_float) parameters[j];
586
    }
587
    params->setParameterValues(paramVector);
Peter Eastman's avatar
Peter Eastman committed
588
    cl.addForce(new OpenCLCustomBondForceInfo(force));
589
590
591
592
593
594
595
596
597
598
599
600
601

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
602
    expressions["real dEdR = "] = forceExpression;
603
604
605
606
607
608
609

    // Create the kernels.

    map<string, string> variables;
    variables["r"] = "r";
    for (int i = 0; i < force.getNumPerBondParameters(); i++) {
        const string& name = force.getPerBondParameterName(i);
610
        variables[name] = "bondParams"+params->getParameterSuffix(i);
611
    }
612
    if (force.getNumGlobalParameters() > 0) {
613
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customBondGlobals", CL_MEM_READ_ONLY);
614
615
616
617
        globals->upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals->getDeviceBuffer(), "float");
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
618
            string value = argName+"["+cl.intToString(i)+"]";
619
620
            variables[name] = value;
        }
621
622
    }
    stringstream compute;
623
624
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
625
626
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
627
    }
peastman's avatar
peastman committed
628
629
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
630
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
631
    map<string, string> replacements;
632
    replacements["COMPUTE_FORCE"] = compute.str();
633
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::bondForce, replacements), force.getForceGroup());
634
635
}

636
/**
 * Synchronize the global parameters of the custom bond force with the context.
 *
 * Reads the current value of every global parameter from the context and, when
 * at least one differs from the cached copy, uploads the whole cached array to
 * the globals buffer again.  Nothing is done when no globals array exists.
 *
 * @param context        the context whose parameter values are read
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (globals != NULL) {
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            // Compare in single precision, since that is what is stored in the buffer.
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals->upload(globalParamValues);
    }
    return 0.0;
}

651
652
653
654
655
656
/**
 * Re-read the per-bond parameters from the force and store them in the
 * parameter set used by this context.
 *
 * The bonds are divided among the contexts of the platform; this context
 * handles the bonds in [startIndex, endIndex).
 *
 * @param context  not used by this method
 * @param force    the CustomBondForce to copy the parameters from
 * @throws OpenMMException if the number of bonds assigned to this context
 *         differs from the stored numBonds
 */
void OpenCLCalcCustomBondForceKernel::copyParametersToContext(ContextImpl& context, const CustomBondForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    if (numBonds != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;
    
    // Record the per-bond parameters.
    
    vector<vector<cl_float> > paramVector(numBonds);
    vector<double> parameters;
    for (int i = 0; i < numBonds; i++) {
        int atom1, atom2;
        force.getBondParameters(startIndex+i, atom1, atom2, parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules();
}

678
class OpenCLHarmonicAngleForceInfo : public OpenCLForceInfo {
679
public:
680
    OpenCLHarmonicAngleForceInfo(const HarmonicAngleForce& force) : OpenCLForceInfo(0), force(force) {
681
682
683
684
    }
    int getNumParticleGroups() {
        return force.getNumAngles();
    }
Peter Eastman's avatar
Peter Eastman committed
685
    void getParticlesInGroup(int index, vector<int>& particles) {
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
        int particle1, particle2, particle3;
        double angle, k;
        force.getAngleParameters(index, particle1, particle2, particle3, angle, k);
        particles.resize(3);
        particles[0] = particle1;
        particles[1] = particle2;
        particles[2] = particle3;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2, particle3;
        double angle1, angle2, k1, k2;
        force.getAngleParameters(group1, particle1, particle2, particle3, angle1, k1);
        force.getAngleParameters(group2, particle1, particle2, particle3, angle2, k2);
        return (angle1 == angle2 && k1 == k2);
    }
private:
    const HarmonicAngleForce& force;
};

/** Free the parameter array owned by this kernel, if one was created. */
OpenCLCalcHarmonicAngleForceKernel::~OpenCLCalcHarmonicAngleForceKernel() {
    // delete on a null pointer does nothing, so no explicit test is required.
    delete params;
}

void OpenCLCalcHarmonicAngleForceKernel::initialize(const System& system, const HarmonicAngleForce& force) {
711
712
713
714
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumAngles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumAngles()/numContexts;
    numAngles = endIndex-startIndex;
715
716
    if (numAngles == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
717
    vector<vector<int> > atoms(numAngles, vector<int>(3));
718
    params = OpenCLArray::create<mm_float2>(cl, numAngles, "angleParams");
719
720
721
    vector<mm_float2> paramVector(numAngles);
    for (int i = 0; i < numAngles; i++) {
        double angle, k;
Peter Eastman's avatar
Peter Eastman committed
722
        force.getAngleParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], angle, k);
723
        paramVector[i] = mm_float2((cl_float) angle, (cl_float) k);
724
725
726

    }
    params->upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
727
    map<string, string> replacements;
728
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::harmonicAngleForce;
Peter Eastman's avatar
Peter Eastman committed
729
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float2");
730
731
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
    cl.addForce(new OpenCLHarmonicAngleForceInfo(force));
732
733
}

734
/**
 * Does no work and always returns 0.0; the interaction was handed to the
 * bonded utilities in initialize().
 *
 * @param context        not used by this method
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcHarmonicAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

738
739
740
741
742
743
/**
 * Re-read the (angle, k) parameters from the force and upload them again.
 *
 * Only the angles assigned to this context, [startIndex, endIndex), are
 * copied; the atom indices returned by the force are ignored.
 *
 * @param context  not used by this method
 * @param force    the HarmonicAngleForce to copy the parameters from
 * @throws OpenMMException if the number of angles assigned to this context
 *         differs from the stored numAngles
 */
void OpenCLCalcHarmonicAngleForceKernel::copyParametersToContext(ContextImpl& context, const HarmonicAngleForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumAngles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumAngles()/numContexts;
    if (numAngles != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of angles has changed");
    if (numAngles == 0)
        return;
    
    // Record the per-angle parameters.
    
    vector<mm_float2> paramVector(numAngles);
    for (int i = 0; i < numAngles; i++) {
        int atom1, atom2, atom3;
        double angle, k;
        force.getAngleParameters(startIndex+i, atom1, atom2, atom3, angle, k);
        paramVector[i] = mm_float2((cl_float) angle, (cl_float) k);
    }
    params->upload(paramVector);
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules();
}

763
764
class OpenCLCustomAngleForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
765
    OpenCLCustomAngleForceInfo(const CustomAngleForce& force) : OpenCLForceInfo(0), force(force) {
766
767
768
769
    }
    int getNumParticleGroups() {
        return force.getNumAngles();
    }
Peter Eastman's avatar
Peter Eastman committed
770
    void getParticlesInGroup(int index, vector<int>& particles) {
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
        int particle1, particle2, particle3;
        vector<double> parameters;
        force.getAngleParameters(index, particle1, particle2, particle3, parameters);
        particles.resize(3);
        particles[0] = particle1;
        particles[1] = particle2;
        particles[2] = particle3;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2, particle3;
        vector<double> parameters1, parameters2;
        force.getAngleParameters(group1, particle1, particle2, particle3, parameters1);
        force.getAngleParameters(group2, particle1, particle2, particle3, parameters2);
        for (int i = 0; i < (int) parameters1.size(); i++)
            if (parameters1[i] != parameters2[i])
                return false;
        return true;
    }
private:
    const CustomAngleForce& force;
};

/** Free the parameter set and the globals array owned by this kernel, if they were created. */
OpenCLCalcCustomAngleForceKernel::~OpenCLCalcCustomAngleForceKernel() {
    // delete on a null pointer does nothing, so no explicit tests are required.
    delete params;
    delete globals;
}

void OpenCLCalcCustomAngleForceKernel::initialize(const System& system, const CustomAngleForce& force) {
801
802
803
804
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumAngles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumAngles()/numContexts;
    numAngles = endIndex-startIndex;
805
806
    if (numAngles == 0)
        return;
807
    vector<vector<int> > atoms(numAngles, vector<int>(3));
808
809
810
811
    params = new OpenCLParameterSet(cl, force.getNumPerAngleParameters(), numAngles, "customAngleParams");
    vector<vector<cl_float> > paramVector(numAngles);
    for (int i = 0; i < numAngles; i++) {
        vector<double> parameters;
812
        force.getAngleParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], parameters);
813
814
815
816
817
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
Peter Eastman's avatar
Peter Eastman committed
818
    cl.addForce(new OpenCLCustomAngleForceInfo(force));
819
820
821
822
823
824
825
826
827
828
829
830
831

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
832
    expressions["real dEdAngle = "] = forceExpression;
833
834
835
836
837
838
839
840
841

    // Create the kernels.

    map<string, string> variables;
    variables["theta"] = "theta";
    for (int i = 0; i < force.getNumPerAngleParameters(); i++) {
        const string& name = force.getPerAngleParameterName(i);
        variables[name] = "angleParams"+params->getParameterSuffix(i);
    }
842
    if (force.getNumGlobalParameters() > 0) {
843
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customAngleGlobals", CL_MEM_READ_ONLY);
844
845
846
847
        globals->upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals->getDeviceBuffer(), "float");
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
848
            string value = argName+"["+cl.intToString(i)+"]";
849
850
            variables[name] = value;
        }
851
852
853
854
    }
    stringstream compute;
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
855
856
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" angleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
857
    }
peastman's avatar
peastman committed
858
859
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
860
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
861
862
    map<string, string> replacements;
    replacements["COMPUTE_FORCE"] = compute.str();
863
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::angleForce, replacements), force.getForceGroup());
864
865
}

866
/**
 * Synchronize the global parameters of the custom angle force with the context.
 *
 * Reads the current value of every global parameter from the context and, when
 * at least one differs from the cached copy, uploads the whole cached array to
 * the globals buffer again.  Nothing is done when no globals array exists.
 *
 * @param context        the context whose parameter values are read
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (globals != NULL) {
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            // Compare in single precision, since that is what is stored in the buffer.
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals->upload(globalParamValues);
    }
    return 0.0;
}

881
882
883
884
885
886
/**
 * Re-read the per-angle parameters from the force and store them in the
 * parameter set used by this context.
 *
 * Only the angles assigned to this context, [startIndex, endIndex), are
 * copied; the atom indices returned by the force are ignored.
 *
 * @param context  not used by this method
 * @param force    the CustomAngleForce to copy the parameters from
 * @throws OpenMMException if the number of angles assigned to this context
 *         differs from the stored numAngles
 */
void OpenCLCalcCustomAngleForceKernel::copyParametersToContext(ContextImpl& context, const CustomAngleForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumAngles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumAngles()/numContexts;
    if (numAngles != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of angles has changed");
    if (numAngles == 0)
        return;
    
    // Record the per-angle parameters.
    
    vector<vector<cl_float> > paramVector(numAngles);
    vector<double> parameters;
    for (int i = 0; i < numAngles; i++) {
        int atom1, atom2, atom3;
        force.getAngleParameters(startIndex+i, atom1, atom2, atom3, parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules();
}

908
909
class OpenCLPeriodicTorsionForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
910
    OpenCLPeriodicTorsionForceInfo(const PeriodicTorsionForce& force) : OpenCLForceInfo(0), force(force) {
911
912
913
914
    }
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
Peter Eastman's avatar
Peter Eastman committed
915
    void getParticlesInGroup(int index, vector<int>& particles) {
916
917
918
919
920
921
922
923
924
925
926
927
928
        int particle1, particle2, particle3, particle4, periodicity;
        double phase, k;
        force.getTorsionParameters(index, particle1, particle2, particle3, particle4, periodicity, phase, k);
        particles.resize(4);
        particles[0] = particle1;
        particles[1] = particle2;
        particles[2] = particle3;
        particles[3] = particle4;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2, particle3, particle4, periodicity1, periodicity2;
        double phase1, phase2, k1, k2;
        force.getTorsionParameters(group1, particle1, particle2, particle3, particle4, periodicity1, phase1, k1);
929
        force.getTorsionParameters(group2, particle1, particle2, particle3, particle4, periodicity2, phase2, k2);
930
931
932
933
934
935
936
937
938
939
940
941
        return (periodicity1 == periodicity2 && phase1 == phase2 && k1 == k2);
    }
private:
    const PeriodicTorsionForce& force;
};

/** Free the parameter array owned by this kernel, if one was created. */
OpenCLCalcPeriodicTorsionForceKernel::~OpenCLCalcPeriodicTorsionForceKernel() {
    // delete on a null pointer does nothing, so no explicit test is required.
    delete params;
}

void OpenCLCalcPeriodicTorsionForceKernel::initialize(const System& system, const PeriodicTorsionForce& force) {
942
943
944
945
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
946
947
    if (numTorsions == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
948
    vector<vector<int> > atoms(numTorsions, vector<int>(4));
949
    params = OpenCLArray::create<mm_float4>(cl, numTorsions, "periodicTorsionParams");
950
951
    vector<mm_float4> paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
Peter Eastman's avatar
Peter Eastman committed
952
        int periodicity;
953
        double phase, k;
Peter Eastman's avatar
Peter Eastman committed
954
        force.getTorsionParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], periodicity, phase, k);
955
        paramVector[i] = mm_float4((cl_float) k, (cl_float) phase, (cl_float) periodicity, 0.0f);
956
957
    }
    params->upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
958
    map<string, string> replacements;
959
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::periodicTorsionForce;
Peter Eastman's avatar
Peter Eastman committed
960
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float4");
961
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
962
    cl.addForce(new OpenCLPeriodicTorsionForceInfo(force));
963
964
}

965
/**
 * Does no work and always returns 0.0; the interaction was handed to the
 * bonded utilities in initialize().
 *
 * @param context        not used by this method
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcPeriodicTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

969
970
971
972
973
974
/**
 * Re-read the (k, phase, periodicity) parameters from the force and upload
 * them again.
 *
 * Only the torsions assigned to this context, [startIndex, endIndex), are
 * copied; the atom indices returned by the force are ignored.
 *
 * @param context  not used by this method
 * @param force    the PeriodicTorsionForce to copy the parameters from
 * @throws OpenMMException if the number of torsions assigned to this context
 *         differs from the stored numTorsions
 */
void OpenCLCalcPeriodicTorsionForceKernel::copyParametersToContext(ContextImpl& context, const PeriodicTorsionForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    if (numTorsions != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");
    if (numTorsions == 0)
        return;
    
    // Record the per-torsion parameters.
    
    vector<mm_float4> paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        int atom1, atom2, atom3, atom4, periodicity;
        double phase, k;
        force.getTorsionParameters(startIndex+i, atom1, atom2, atom3, atom4, periodicity, phase, k);
        paramVector[i] = mm_float4((cl_float) k, (cl_float) phase, (cl_float) periodicity, 0.0f);
    }
    params->upload(paramVector);
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules();
}

994
995
class OpenCLRBTorsionForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
996
    OpenCLRBTorsionForceInfo(const RBTorsionForce& force) : OpenCLForceInfo(0), force(force) {
997
998
999
1000
    }
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
Peter Eastman's avatar
Peter Eastman committed
1001
    void getParticlesInGroup(int index, vector<int>& particles) {
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
        int particle1, particle2, particle3, particle4;
        double c0, c1, c2, c3, c4, c5;
        force.getTorsionParameters(index, particle1, particle2, particle3, particle4, c0, c1, c2, c3, c4, c5);
        particles.resize(4);
        particles[0] = particle1;
        particles[1] = particle2;
        particles[2] = particle3;
        particles[3] = particle4;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2, particle3, particle4;
        double c0a, c0b, c1a, c1b, c2a, c2b, c3a, c3b, c4a, c4b, c5a, c5b;
        force.getTorsionParameters(group1, particle1, particle2, particle3, particle4, c0a, c1a, c2a, c3a, c4a, c5a);
1015
        force.getTorsionParameters(group2, particle1, particle2, particle3, particle4, c0b, c1b, c2b, c3b, c4b, c5b);
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
        return (c0a == c0b && c1a == c1b && c2a == c2b && c3a == c3b && c4a == c4b && c5a == c5b);
    }
private:
    const RBTorsionForce& force;
};

/** Free the parameter array owned by this kernel, if one was created. */
OpenCLCalcRBTorsionForceKernel::~OpenCLCalcRBTorsionForceKernel() {
    // delete on a null pointer does nothing, so no explicit test is required.
    delete params;
}

void OpenCLCalcRBTorsionForceKernel::initialize(const System& system, const RBTorsionForce& force) {
1028
1029
1030
1031
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1032
1033
    if (numTorsions == 0)
        return;
Peter Eastman's avatar
Peter Eastman committed
1034
    vector<vector<int> > atoms(numTorsions, vector<int>(4));
1035
    params = OpenCLArray::create<mm_float8>(cl, numTorsions, "rbTorsionParams");
1036
1037
1038
    vector<mm_float8> paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        double c0, c1, c2, c3, c4, c5;
Peter Eastman's avatar
Peter Eastman committed
1039
        force.getTorsionParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], c0, c1, c2, c3, c4, c5);
1040
        paramVector[i] = mm_float8((cl_float) c0, (cl_float) c1, (cl_float) c2, (cl_float) c3, (cl_float) c4, (cl_float) c5, 0.0f, 0.0f);
1041
1042
1043

    }
    params->upload(paramVector);
Peter Eastman's avatar
Peter Eastman committed
1044
    map<string, string> replacements;
1045
    replacements["COMPUTE_FORCE"] = OpenCLKernelSources::rbTorsionForce;
Peter Eastman's avatar
Peter Eastman committed
1046
    replacements["PARAMS"] = cl.getBondedUtilities().addArgument(params->getDeviceBuffer(), "float8");
1047
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
1048
    cl.addForce(new OpenCLRBTorsionForceInfo(force));
1049
1050
}

1051
/**
 * Does no work and always returns 0.0; the interaction was handed to the
 * bonded utilities in initialize().
 *
 * @param context        not used by this method
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcRBTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

1055
1056
1057
1058
1059
1060
/**
 * Re-read the c0..c5 coefficients from the force and upload them again.
 *
 * Only the torsions assigned to this context, [startIndex, endIndex), are
 * copied; the atom indices returned by the force are ignored.
 *
 * @param context  not used by this method
 * @param force    the RBTorsionForce to copy the parameters from
 * @throws OpenMMException if the number of torsions assigned to this context
 *         differs from the stored numTorsions
 */
void OpenCLCalcRBTorsionForceKernel::copyParametersToContext(ContextImpl& context, const RBTorsionForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    if (numTorsions != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");
    if (numTorsions == 0)
        return;
    
    // Record the per-torsion parameters.
    
    vector<mm_float8> paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        int atom1, atom2, atom3, atom4;
        double c0, c1, c2, c3, c4, c5;
        force.getTorsionParameters(startIndex+i, atom1, atom2, atom3, atom4, c0, c1, c2, c3, c4, c5);
        paramVector[i] = mm_float8((cl_float) c0, (cl_float) c1, (cl_float) c2, (cl_float) c3, (cl_float) c4, (cl_float) c5, 0.0f, 0.0f);
    }
    params->upload(paramVector);
    
    // Mark that the current reordering may be invalid.
    
    cl.invalidateMolecules();
}

1080
1081
class OpenCLCMAPTorsionForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
1082
    OpenCLCMAPTorsionForceInfo(const CMAPTorsionForce& force) : OpenCLForceInfo(0), force(force) {
1083
1084
1085
1086
    }
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
Peter Eastman's avatar
Peter Eastman committed
1087
    void getParticlesInGroup(int index, vector<int>& particles) {
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
        int map, a1, a2, a3, a4, b1, b2, b3, b4;
        force.getTorsionParameters(index, map, a1, a2, a3, a4, b1, b2, b3, b4);
        particles.resize(8);
        particles[0] = a1;
        particles[1] = a2;
        particles[2] = a3;
        particles[3] = a4;
        particles[4] = b1;
        particles[5] = b2;
        particles[6] = b3;
        particles[7] = b4;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int map1, map2, a1, a2, a3, a4, b1, b2, b3, b4;
        force.getTorsionParameters(group1, map1, a1, a2, a3, a4, b1, b2, b3, b4);
        force.getTorsionParameters(group2, map2, a1, a2, a3, a4, b1, b2, b3, b4);
        return (map1 == map2);
    }
private:
    const CMAPTorsionForce& force;
};

/** Free the coefficient, map position and torsion map arrays owned by this kernel, if they were created. */
OpenCLCalcCMAPTorsionForceKernel::~OpenCLCalcCMAPTorsionForceKernel() {
    // delete on a null pointer does nothing, so no explicit tests are required.
    delete coefficients;
    delete mapPositions;
    delete torsionMaps;
}

void OpenCLCalcCMAPTorsionForceKernel::initialize(const System& system, const CMAPTorsionForce& force) {
1120
1121
1122
1123
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
    if (numTorsions == 0)
        return;
    int numMaps = force.getNumMaps();
    vector<mm_float4> coeffVec;
    vector<mm_int2> mapPositionsVec(numMaps);
    vector<double> energy;
    vector<vector<double> > c;
    int currentPosition = 0;
    for (int i = 0; i < numMaps; i++) {
        int size;
        force.getMapParameters(i, size, energy);
        CMAPTorsionForceImpl::calcMapDerivatives(size, energy, c);
        mapPositionsVec[i] = mm_int2(currentPosition, size);
        currentPosition += 4*size*size;
        for (int j = 0; j < size*size; j++) {
1139
1140
1141
1142
            coeffVec.push_back(mm_float4((float) c[j][0], (float) c[j][1], (float) c[j][2], (float) c[j][3]));
            coeffVec.push_back(mm_float4((float) c[j][4], (float) c[j][5], (float) c[j][6], (float) c[j][7]));
            coeffVec.push_back(mm_float4((float) c[j][8], (float) c[j][9], (float) c[j][10], (float) c[j][11]));
            coeffVec.push_back(mm_float4((float) c[j][12], (float) c[j][13], (float) c[j][14], (float) c[j][15]));
1143
1144
        }
    }
1145
    vector<vector<int> > atoms(numTorsions, vector<int>(8));
1146
    vector<cl_int> torsionMapsVec(numTorsions);
1147
1148
    for (int i = 0; i < numTorsions; i++)
        force.getTorsionParameters(startIndex+i, torsionMapsVec[i], atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], atoms[i][4], atoms[i][5], atoms[i][6], atoms[i][7]);
1149
1150
1151
    coefficients = OpenCLArray::create<mm_float4>(cl, coeffVec.size(), "cmapTorsionCoefficients");
    mapPositions = OpenCLArray::create<mm_int2>(cl, numMaps, "cmapTorsionMapPositions");
    torsionMaps = OpenCLArray::create<cl_int>(cl, numTorsions, "cmapTorsionMaps");
1152
1153
1154
    coefficients->upload(coeffVec);
    mapPositions->upload(mapPositionsVec);
    torsionMaps->upload(torsionMapsVec);
1155
1156
1157
1158
    map<string, string> replacements;
    replacements["COEFF"] = cl.getBondedUtilities().addArgument(coefficients->getDeviceBuffer(), "float4");
    replacements["MAP_POS"] = cl.getBondedUtilities().addArgument(mapPositions->getDeviceBuffer(), "int2");
    replacements["MAPS"] = cl.getBondedUtilities().addArgument(torsionMaps->getDeviceBuffer(), "int");
1159
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::cmapTorsionForce, replacements), force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
1160
    cl.addForce(new OpenCLCMAPTorsionForceInfo(force));
1161
1162
}

1163
/**
 * Does no work and always returns 0.0; the interaction was handed to the
 * bonded utilities in initialize().
 *
 * @param context        not used by this method
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcCMAPTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

1167
1168
1169
/**
 * Copy changed force parameters into the context.
 *
 * The body is empty: as written, this kernel does not refresh its coefficient,
 * map position or torsion map arrays after initialization.
 * NOTE(review): confirm whether updating CMAP parameters is intentionally
 * unsupported here.
 *
 * @param context  not used by this method
 * @param force    not used by this method
 */
void OpenCLCalcCMAPTorsionForceKernel::copyParametersToContext(ContextImpl& context, const CMAPTorsionForce& force) {
}

1170
1171
class OpenCLCustomTorsionForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
1172
    OpenCLCustomTorsionForceInfo(const CustomTorsionForce& force) : OpenCLForceInfo(0), force(force) {
1173
1174
1175
1176
    }
    int getNumParticleGroups() {
        return force.getNumTorsions();
    }
Peter Eastman's avatar
Peter Eastman committed
1177
    void getParticlesInGroup(int index, vector<int>& particles) {
1178
1179
1180
        int particle1, particle2, particle3, particle4;
        vector<double> parameters;
        force.getTorsionParameters(index, particle1, particle2, particle3, particle4, parameters);
Peter Eastman's avatar
Bug fix  
Peter Eastman committed
1181
        particles.resize(4);
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
        particles[0] = particle1;
        particles[1] = particle2;
        particles[2] = particle3;
        particles[3] = particle4;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2, particle3, particle4;
        vector<double> parameters1, parameters2;
        force.getTorsionParameters(group1, particle1, particle2, particle3, particle4, parameters1);
        force.getTorsionParameters(group2, particle1, particle2, particle3, particle4, parameters2);
        for (int i = 0; i < (int) parameters1.size(); i++)
            if (parameters1[i] != parameters2[i])
                return false;
        return true;
    }
private:
    const CustomTorsionForce& force;
};

/** Free the parameter set and the globals array owned by this kernel, if they were created. */
OpenCLCalcCustomTorsionForceKernel::~OpenCLCalcCustomTorsionForceKernel() {
    // delete on a null pointer does nothing, so no explicit tests are required.
    delete params;
    delete globals;
}

void OpenCLCalcCustomTorsionForceKernel::initialize(const System& system, const CustomTorsionForce& force) {
1209
1210
1211
1212
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumTorsions()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumTorsions()/numContexts;
    numTorsions = endIndex-startIndex;
1213
1214
    if (numTorsions == 0)
        return;
1215
    vector<vector<int> > atoms(numTorsions, vector<int>(4));
1216
1217
1218
1219
    params = new OpenCLParameterSet(cl, force.getNumPerTorsionParameters(), numTorsions, "customTorsionParams");
    vector<vector<cl_float> > paramVector(numTorsions);
    for (int i = 0; i < numTorsions; i++) {
        vector<double> parameters;
1220
        force.getTorsionParameters(startIndex+i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], parameters);
1221
1222
1223
1224
1225
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
Peter Eastman's avatar
Peter Eastman committed
1226
    cl.addForce(new OpenCLCustomTorsionForceInfo(force));
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("theta").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
1240
    expressions["real dEdAngle = "] = forceExpression;
1241
1242
1243
1244
1245
1246
1247
1248
1249

    // Create the kernels.

    map<string, string> variables;
    variables["theta"] = "theta";
    for (int i = 0; i < force.getNumPerTorsionParameters(); i++) {
        const string& name = force.getPerTorsionParameterName(i);
        variables[name] = "torsionParams"+params->getParameterSuffix(i);
    }
1250
    if (force.getNumGlobalParameters() > 0) {
1251
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customTorsionGlobals", CL_MEM_READ_ONLY);
1252
1253
1254
1255
        globals->upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals->getDeviceBuffer(), "float");
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
1256
            string value = argName+"["+cl.intToString(i)+"]";
1257
1258
            variables[name] = value;
        }
1259
1260
1261
1262
    }
    stringstream compute;
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
1263
1264
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" torsionParams"<<(i+1)<<" = "<<argName<<"[index];\n";
1265
    }
peastman's avatar
peastman committed
1266
1267
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
1268
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
1269
1270
    map<string, string> replacements;
    replacements["COMPUTE_FORCE"] = compute.str();
1271
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::torsionForce, replacements), force.getForceGroup());
1272
1273
}

1274
/**
 * Bring the device copy of the global parameters up to date with the context.
 *
 * Each global parameter is read from the context; if any value differs from the
 * cached one, the whole cached vector is uploaded again.  This method itself
 * contributes no energy and always returns 0.0.
 *
 * @param context        the context to read global parameter values from
 * @param includeForces  not used by this method
 * @param includeEnergy  not used by this method
 * @return always 0.0
 */
double OpenCLCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    // Without a globals buffer there is nothing to synchronize.
    if (globals == NULL)
        return 0.0;

    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int index = 0; index < numGlobals; index++) {
        const cl_float current = (cl_float) context.getParameter(globalParamNames[index]);
        if (current != globalParamValues[index])
            needsUpload = true;
        globalParamValues[index] = current;
    }

    // Only touch the device when at least one value actually changed.
    if (needsUpload)
        globals->upload(globalParamValues);
    return 0.0;
}

1289
1290
1291
1292
1293
1294
/**
 * Re-read the per-torsion parameters from the force and store them in the
 * parameter set used by this kernel.
 *
 * Only the slice of torsions assigned to this context (torsions are divided
 * evenly among all contexts of the platform) is processed.
 *
 * @param context  not used by this method
 * @param force    the CustomTorsionForce to copy the parameters from
 * @throws OpenMMException if the number of torsions assigned to this context
 *         no longer matches the number recorded in numTorsions
 */
void OpenCLCalcCustomTorsionForceKernel::copyParametersToContext(ContextImpl& context, const CustomTorsionForce& force) {
    // Determine which range of torsions belongs to this context.
    const int contextCount = cl.getPlatformData().contexts.size();
    const int firstTorsion = cl.getContextIndex()*force.getNumTorsions()/contextCount;
    const int lastTorsion = (cl.getContextIndex()+1)*force.getNumTorsions()/contextCount;
    if (lastTorsion-firstTorsion != numTorsions)
        throw OpenMMException("updateParametersInContext: The number of torsions has changed");
    if (numTorsions == 0)
        return;

    // Record the per-torsion parameters, converted to single precision.

    vector<vector<cl_float> > torsionValues(numTorsions);
    vector<double> values;
    for (int torsion = 0; torsion < numTorsions; torsion++) {
        int a1, a2, a3, a4;
        force.getTorsionParameters(firstTorsion+torsion, a1, a2, a3, a4, values);
        vector<cl_float>& row = torsionValues[torsion];
        row.reserve(values.size());
        for (int k = 0; k < (int) values.size(); k++)
            row.push_back((cl_float) values[k]);
    }
    params->setParameterValues(torsionValues);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
class OpenCLNonbondedForceInfo : public OpenCLForceInfo {
public:
    OpenCLNonbondedForceInfo(int requiredBuffers, const NonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        double charge1, charge2, sigma1, sigma2, epsilon1, epsilon2;
        force.getParticleParameters(particle1, charge1, sigma1, epsilon1);
        force.getParticleParameters(particle2, charge2, sigma2, epsilon2);
        return (charge1 == charge2 && sigma1 == sigma2 && epsilon1 == epsilon2);
    }
    int getNumParticleGroups() {
        return force.getNumExceptions();
    }
Peter Eastman's avatar
Peter Eastman committed
1329
    void getParticlesInGroup(int index, vector<int>& particles) {
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
        int particle1, particle2;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(index, particle1, particle2, chargeProd, sigma, epsilon);
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    bool areGroupsIdentical(int group1, int group2) {
        int particle1, particle2;
        double chargeProd1, chargeProd2, sigma1, sigma2, epsilon1, epsilon2;
        force.getExceptionParameters(group1, particle1, particle2, chargeProd1, sigma1, epsilon1);
        force.getExceptionParameters(group2, particle1, particle2, chargeProd2, sigma2, epsilon2);
        return (chargeProd1 == chargeProd2 && sigma1 == sigma2 && epsilon1 == epsilon2);
    }
private:
    const NonbondedForce& force;
};

1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
/**
 * IO adapter handed to the CPU PME kernel.  It downloads positions/charges from
 * the OpenCL context and adds the forces it is given to the context's force
 * buffer by way of a temporary device array and the addForces kernel.
 */
class OpenCLCalcNonbondedForceKernel::PmeIO : public CalcPmeReciprocalForceKernel::IO {
public:
    PmeIO(OpenCLContext& cl, cl::Kernel addForcesKernel) : cl(cl), addForcesKernel(addForcesKernel), forceTemp(NULL) {
        // Allocate the staging array and bind it as the kernel's first argument.
        forceTemp = OpenCLArray::create<mm_float4>(cl, cl.getNumAtoms(), "PmeForce");
        cl::Buffer& stagingBuffer = forceTemp->getDeviceBuffer();
        addForcesKernel.setArg<cl::Buffer>(0, stagingBuffer);
    }
    ~PmeIO() {
        // Deleting a null pointer is a no-op, so no explicit check is required.
        delete forceTemp;
    }
    /**
     * Download the current posq array and expose it as a flat float array.
     */
    float* getPosq() {
        cl.getPosq().download(posq);
        mm_float4* firstElement = &posq[0];
        return (float*) firstElement;
    }
    /**
     * Upload the given forces to the staging array and run the addForces kernel
     * against the context's force buffer.
     */
    void setForce(float* force) {
        forceTemp->upload(force);
        // The force buffer is bound on every call rather than once at construction.
        addForcesKernel.setArg<cl::Buffer>(1, cl.getForce().getDeviceBuffer());
        const int numAtoms = cl.getNumAtoms();
        cl.executeKernel(addForcesKernel, numAtoms);
    }
private:
    OpenCLContext& cl;
    vector<mm_float4> posq;
    OpenCLArray* forceTemp;
    cl::Kernel addForcesKernel;
};

/**
 * Pre-computation that starts the CPU PME reciprocal space calculation.
 */
class OpenCLCalcNonbondedForceKernel::PmePreComputation : public OpenCLContext::ForcePreComputation {
public:
    PmePreComputation(OpenCLContext& cl, Kernel& pme, CalcPmeReciprocalForceKernel::IO& io) : cl(cl), pme(pme), io(io) {
    }
    /**
     * Build box vectors from the context's periodic box size (diagonal entries
     * only) and begin the PME computation with them.
     */
    void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        const Vec3 boxA(cl.getPeriodicBoxSize().x, 0, 0);
        const Vec3 boxB(0, cl.getPeriodicBoxSize().y, 0);
        const Vec3 boxC(0, 0, cl.getPeriodicBoxSize().z);
        Vec3 boxVectors[3] = {boxA, boxB, boxC};
        CalcPmeReciprocalForceKernel& pmeKernel = pme.getAs<CalcPmeReciprocalForceKernel>();
        pmeKernel.beginComputation(io, boxVectors, includeEnergy);
    }
private:
    OpenCLContext& cl;
    Kernel pme;
    CalcPmeReciprocalForceKernel::IO& io;
};

/**
 * Post-computation that finishes the CPU PME reciprocal space calculation and
 * returns whatever energy value it reports.
 */
class OpenCLCalcNonbondedForceKernel::PmePostComputation : public OpenCLContext::ForcePostComputation {
public:
    PmePostComputation(Kernel& pme, CalcPmeReciprocalForceKernel::IO& io) : pme(pme), io(io) {
    }
    double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        CalcPmeReciprocalForceKernel& pmeKernel = pme.getAs<CalcPmeReciprocalForceKernel>();
        return pmeKernel.finishComputation(io);
    }
private:
    Kernel pme;
    CalcPmeReciprocalForceKernel::IO& io;
};

1400
1401
/**
 * Pre-computation that makes a secondary command queue wait for the work
 * already enqueued on the context's main queue.
 */
class OpenCLCalcNonbondedForceKernel::SyncQueuePreComputation : public OpenCLContext::ForcePreComputation {
public:
    SyncQueuePreComputation(OpenCLContext& cl, cl::CommandQueue queue, int forceGroup) : cl(cl), queue(queue), events(1), forceGroup(forceGroup) {
    }
    /**
     * If this object's force group is among the requested groups, enqueue a
     * marker on the main queue and have the secondary queue wait for it.
     */
    void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        const int groupMask = 1<<forceGroup;
        if ((groups&groupMask) == 0)
            return;
        cl.getQueue().enqueueMarker(&events[0]);
        queue.enqueueWaitForEvents(events);
    }
private:
    OpenCLContext& cl;
    cl::CommandQueue queue;
    vector<cl::Event> events;
    int forceGroup;
};

/**
 * Post-computation that makes the context's main queue wait for an event
 * owned elsewhere (held here by reference).
 */
class OpenCLCalcNonbondedForceKernel::SyncQueuePostComputation : public OpenCLContext::ForcePostComputation {
public:
    SyncQueuePostComputation(OpenCLContext& cl, cl::Event& event, int forceGroup) : cl(cl), event(event), events(1), forceGroup(forceGroup) {
    }
    /**
     * If this object's force group is among the requested groups, have the main
     * queue wait for the referenced event.  Contributes no energy.
     *
     * @return always 0.0
     */
    double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
        const int groupMask = 1<<forceGroup;
        const bool groupRequested = ((groups&groupMask) != 0);
        if (groupRequested) {
            // Copy the current value of the referenced event into the wait list.
            events[0] = event;
            cl.getQueue().enqueueWaitForEvents(events);
        }
        return 0.0;
    }
private:
    OpenCLContext& cl;
    cl::Event& event;
    vector<cl::Event> events;
    int forceGroup;
};

1435
1436
1437
/**
 * Release every heap object this kernel may have created.  Members that were
 * never allocated are expected to be NULL; applying delete to a null pointer
 * has no effect, so no explicit checks are needed.
 */
OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() {
    // Per-particle and per-exception parameter arrays.
    delete sigmaEpsilon;
    delete exceptionParams;

    // Ewald reciprocal space data.
    delete cosSinSums;

    // PME grids and b-spline data.
    delete pmeGrid;
    delete pmeGrid2;
    delete pmeBsplineModuliX;
    delete pmeBsplineModuliY;
    delete pmeBsplineModuliZ;
    delete pmeBsplineTheta;
    delete pmeAtomRange;
    delete pmeAtomGridIndex;

    // PME helper objects.
    delete sort;
    delete fft;
    delete pmeio;
}

void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const NonbondedForce& force) {

    // Identify which exceptions are 1-4 interactions.

    vector<pair<int, int> > exclusions;
    vector<int> exceptions;
    for (int i = 0; i < force.getNumExceptions(); i++) {
        int particle1, particle2;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(i, particle1, particle2, chargeProd, sigma, epsilon);
        exclusions.push_back(pair<int, int>(particle1, particle2));
        if (chargeProd != 0.0 || epsilon != 0.0)
            exceptions.push_back(i);
    }

    // Initialize nonbonded interactions.

    int numParticles = force.getNumParticles();
1484
    sigmaEpsilon = OpenCLArray::create<mm_float2>(cl, cl.getPaddedNumAtoms(), "sigmaEpsilon");
1485
1486
1487
    vector<mm_float4> posqf(cl.getPaddedNumAtoms(), mm_float4(0,0,0,0));
    vector<mm_double4> posqd(cl.getPaddedNumAtoms(), mm_double4(0,0,0,0));
    vector<mm_float2> sigmaEpsilonVector(cl.getPaddedNumAtoms(), mm_float2(0,0));
1488
    vector<vector<int> > exclusionList(numParticles);
1489
    double sumSquaredCharges = 0.0;
1490
1491
    hasCoulomb = false;
    hasLJ = false;
1492
1493
1494
    for (int i = 0; i < numParticles; i++) {
        double charge, sigma, epsilon;
        force.getParticleParameters(i, charge, sigma, epsilon);
1495
1496
1497
1498
        if (cl.getUseDoublePrecision())
            posqd[i] = mm_double4(0, 0, 0, charge);
        else
            posqf[i] = mm_float4(0, 0, 0, (float) charge);
1499
        sigmaEpsilonVector[i] = mm_float2((float) (0.5*sigma), (float) (2.0*sqrt(epsilon)));
1500
        exclusionList[i].push_back(i);
1501
        sumSquaredCharges += charge*charge;
1502
1503
1504
1505
        if (charge != 0.0)
            hasCoulomb = true;
        if (epsilon != 0.0)
            hasLJ = true;
1506
1507
1508
1509
1510
    }
    for (int i = 0; i < (int) exclusions.size(); i++) {
        exclusionList[exclusions[i].first].push_back(exclusions[i].second);
        exclusionList[exclusions[i].second].push_back(exclusions[i].first);
    }
1511
1512
1513
1514
    if (cl.getUseDoublePrecision())
        cl.getPosq().upload(posqd);
    else
        cl.getPosq().upload(posqf);
1515
1516
1517
    sigmaEpsilon->upload(sigmaEpsilonVector);
    bool useCutoff = (force.getNonbondedMethod() != NonbondedForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != NonbondedForce::NoCutoff && force.getNonbondedMethod() != NonbondedForce::CutoffNonPeriodic);
1518
    map<string, string> defines;
1519
1520
    defines["HAS_COULOMB"] = (hasCoulomb ? "1" : "0");
    defines["HAS_LENNARD_JONES"] = (hasLJ ? "1" : "0");
1521
    defines["USE_LJ_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
1522
    if (useCutoff) {
1523
1524
        // Compute the reaction field constants.

1525
1526
        double reactionFieldK = pow(force.getCutoffDistance(), -3.0)*(force.getReactionFieldDielectric()-1.0)/(2.0*force.getReactionFieldDielectric()+1.0);
        double reactionFieldC = (1.0 / force.getCutoffDistance())*(3.0*force.getReactionFieldDielectric())/(2.0*force.getReactionFieldDielectric()+1.0);
1527
1528
        defines["REACTION_FIELD_K"] = cl.doubleToString(reactionFieldK);
        defines["REACTION_FIELD_C"] = cl.doubleToString(reactionFieldC);
1529
1530
1531
1532
1533
1534
1535
1536
1537
        
        // Compute the switching coefficients.
        
        if (force.getUseSwitchingFunction()) {
            defines["LJ_SWITCH_CUTOFF"] = cl.doubleToString(force.getSwitchingDistance());
            defines["LJ_SWITCH_C3"] = cl.doubleToString(10/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 3.0));
            defines["LJ_SWITCH_C4"] = cl.doubleToString(15/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 4.0));
            defines["LJ_SWITCH_C5"] = cl.doubleToString(6/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 5.0));
        }
1538
    }
1539
    if (force.getUseDispersionCorrection() && cl.getContextIndex() == 0)
1540
1541
1542
        dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(system, force);
    else
        dispersionCoefficient = 0.0;
1543
    alpha = 0;
1544
    ewaldSelfEnergy = 0.0;
1545
    if (force.getNonbondedMethod() == NonbondedForce::Ewald) {
1546
1547
1548
1549
        // Compute the Ewald parameters.

        int kmaxx, kmaxy, kmaxz;
        NonbondedForceImpl::calcEwaldParameters(system, force, alpha, kmaxx, kmaxy, kmaxz);
1550
1551
        defines["EWALD_ALPHA"] = cl.doubleToString(alpha);
        defines["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI));
1552
        defines["USE_EWALD"] = "1";
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
        if (cl.getContextIndex() == 0) {
            ewaldSelfEnergy = -ONE_4PI_EPS0*alpha*sumSquaredCharges/sqrt(M_PI);

            // Create the reciprocal space kernels.

            map<string, string> replacements;
            replacements["NUM_ATOMS"] = cl.intToString(numParticles);
            replacements["KMAX_X"] = cl.intToString(kmaxx);
            replacements["KMAX_Y"] = cl.intToString(kmaxy);
            replacements["KMAX_Z"] = cl.intToString(kmaxz);
            replacements["EXP_COEFFICIENT"] = cl.doubleToString(-1.0/(4.0*alpha*alpha));
            cl::Program program = cl.createProgram(OpenCLKernelSources::ewald, replacements);
            ewaldSumsKernel = cl::Kernel(program, "calculateEwaldCosSinSums");
            ewaldForcesKernel = cl::Kernel(program, "calculateEwaldForces");
            int elementSize = (cl.getUseDoublePrecision() ? sizeof(mm_double2) : sizeof(mm_float2));
            cosSinSums = new OpenCLArray(cl, (2*kmaxx-1)*(2*kmaxy-1)*(2*kmaxz-1), elementSize, "cosSinSums");
        }
    }
    else if (force.getNonbondedMethod() == NonbondedForce::PME) {
1572
1573
1574
1575
1576
1577
1578
        // Compute the PME parameters.

        int gridSizeX, gridSizeY, gridSizeZ;
        NonbondedForceImpl::calcPMEParameters(system, force, alpha, gridSizeX, gridSizeY, gridSizeZ);
        gridSizeX = OpenCLFFT3D::findLegalDimension(gridSizeX);
        gridSizeY = OpenCLFFT3D::findLegalDimension(gridSizeY);
        gridSizeZ = OpenCLFFT3D::findLegalDimension(gridSizeZ);
1579
1580
        defines["EWALD_ALPHA"] = cl.doubleToString(alpha);
        defines["TWO_OVER_SQRT_PI"] = cl.doubleToString(2.0/sqrt(M_PI));
1581
        defines["USE_EWALD"] = "1";
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
        if (cl.getContextIndex() == 0) {
            ewaldSelfEnergy = -ONE_4PI_EPS0*alpha*sumSquaredCharges/sqrt(M_PI);
            pmeDefines["PME_ORDER"] = cl.intToString(PmeOrder);
            pmeDefines["NUM_ATOMS"] = cl.intToString(numParticles);
            pmeDefines["RECIP_EXP_FACTOR"] = cl.doubleToString(M_PI*M_PI/(alpha*alpha));
            pmeDefines["GRID_SIZE_X"] = cl.intToString(gridSizeX);
            pmeDefines["GRID_SIZE_Y"] = cl.intToString(gridSizeY);
            pmeDefines["GRID_SIZE_Z"] = cl.intToString(gridSizeZ);
            pmeDefines["EPSILON_FACTOR"] = cl.doubleToString(sqrt(ONE_4PI_EPS0));
            bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
            if (deviceIsCpu)
                pmeDefines["DEVICE_IS_CPU"] = "1";
            if (cl.getPlatformData().useCpuPme) {
                // Create the CPU PME kernel.

                try {
                    cpuPme = getPlatform().createKernel(CalcPmeReciprocalForceKernel::Name(), *cl.getPlatformData().context);
                    cpuPme.getAs<CalcPmeReciprocalForceKernel>().initialize(gridSizeX, gridSizeY, gridSizeZ, numParticles, alpha);
                    cl::Program program = cl.createProgram(OpenCLKernelSources::pme, pmeDefines);
                    cl::Kernel addForcesKernel = cl::Kernel(program, "addForces");
                    pmeio = new PmeIO(cl, addForcesKernel);
                    cl.addPreComputation(new PmePreComputation(cl, cpuPme, *pmeio));
                    cl.addPostComputation(new PmePostComputation(cpuPme, *pmeio));
1605
                }
1606
1607
                catch (OpenMMException& ex) {
                    // The CPU PME plugin isn't available.
1608
                }
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
            }
            if (pmeio == NULL) {
                // Create required data structures.

                int elementSize = (cl.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
                pmeGrid = new OpenCLArray(cl, gridSizeX*gridSizeY*gridSizeZ, 2*elementSize, "pmeGrid");
                cl.addAutoclearBuffer(*pmeGrid);
                pmeGrid2 = new OpenCLArray(cl, gridSizeX*gridSizeY*gridSizeZ, 2*elementSize, "pmeGrid2");
                pmeBsplineModuliX = new OpenCLArray(cl, gridSizeX, elementSize, "pmeBsplineModuliX");
                pmeBsplineModuliY = new OpenCLArray(cl, gridSizeY, elementSize, "pmeBsplineModuliY");
                pmeBsplineModuliZ = new OpenCLArray(cl, gridSizeZ, elementSize, "pmeBsplineModuliZ");
                pmeBsplineTheta = new OpenCLArray(cl, PmeOrder*numParticles, 4*elementSize, "pmeBsplineTheta");
                pmeAtomRange = OpenCLArray::create<cl_int>(cl, gridSizeX*gridSizeY*gridSizeZ+1, "pmeAtomRange");
                pmeAtomGridIndex = OpenCLArray::create<mm_int2>(cl, numParticles, "pmeAtomGridIndex");
                sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms());
                fft = new OpenCLFFT3D(cl, gridSizeX, gridSizeY, gridSizeZ);
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
                string vendor = cl.getDevice().getInfo<CL_DEVICE_VENDOR>();
                usePmeQueue = (vendor.size() >= 6 && vendor.substr(0, 6) == "NVIDIA");
                if (usePmeQueue) {
                    pmeQueue = cl::CommandQueue(cl.getContext(), cl.getDevice());
                    int recipForceGroup = force.getReciprocalSpaceForceGroup();
                    if (recipForceGroup < 0)
                        recipForceGroup = force.getForceGroup();
                    cl.addPreComputation(new SyncQueuePreComputation(cl, pmeQueue, recipForceGroup));
                    cl.addPostComputation(new SyncQueuePostComputation(cl, pmeSyncEvent, recipForceGroup));
                }
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650

                // Initialize the b-spline moduli.

                int maxSize = max(max(gridSizeX, gridSizeY), gridSizeZ);
                vector<double> data(PmeOrder);
                vector<double> ddata(PmeOrder);
                vector<double> bsplines_data(maxSize);
                data[PmeOrder-1] = 0.0;
                data[1] = 0.0;
                data[0] = 1.0;
                for (int i = 3; i < PmeOrder; i++) {
                    double div = 1.0/(i-1.0);
                    data[i-1] = 0.0;
                    for (int j = 1; j < (i-1); j++)
                        data[i-j-1] = div*(j*data[i-j-2]+(i-j)*data[i-j-1]);
                    data[0] = div*data[0];
1651
                }
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682

                // Differentiate.

                ddata[0] = -data[0];
                for (int i = 1; i < PmeOrder; i++)
                    ddata[i] = data[i-1]-data[i];
                double div = 1.0/(PmeOrder-1);
                data[PmeOrder-1] = 0.0;
                for (int i = 1; i < (PmeOrder-1); i++)
                    data[PmeOrder-i-1] = div*(i*data[PmeOrder-i-2]+(PmeOrder-i)*data[PmeOrder-i-1]);
                data[0] = div*data[0];
                for (int i = 0; i < maxSize; i++)
                    bsplines_data[i] = 0.0;
                for (int i = 1; i <= PmeOrder; i++)
                    bsplines_data[i] = data[i-1];

                // Evaluate the actual bspline moduli for X/Y/Z.

                for(int dim = 0; dim < 3; dim++) {
                    int ndata = (dim == 0 ? gridSizeX : dim == 1 ? gridSizeY : gridSizeZ);
                    vector<cl_double> moduli(ndata);
                    for (int i = 0; i < ndata; i++) {
                        double sc = 0.0;
                        double ss = 0.0;
                        for (int j = 0; j < ndata; j++) {
                            double arg = (2.0*M_PI*i*j)/ndata;
                            sc += bsplines_data[j]*cos(arg);
                            ss += bsplines_data[j]*sin(arg);
                        }
                        moduli[i] = (float) (sc*sc+ss*ss);
                    }
1683
                    for (int i = 0; i < ndata; i++)
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
                    {
                        if (moduli[i] < 1.0e-7)
                            moduli[i] = (moduli[i-1]+moduli[i+1])*0.5f;
                    }
                    if (cl.getUseDoublePrecision()) {
                        if (dim == 0)
                            pmeBsplineModuliX->upload(moduli);
                        else if (dim == 1)
                            pmeBsplineModuliY->upload(moduli);
                        else
                            pmeBsplineModuliZ->upload(moduli);
                    }
                    else {
                        vector<float> modulif(ndata);
                        for (int i = 0; i < ndata; i++)
                            modulif[i] = (float) moduli[i];
                        if (dim == 0)
                            pmeBsplineModuliX->upload(modulif);
                        else if (dim == 1)
                            pmeBsplineModuliY->upload(modulif);
                        else
                            pmeBsplineModuliZ->upload(modulif);
                    }
1707
                }
1708
            }
1709
1710
        }
    }
1711
1712
1713

    // Add the interaction to the default nonbonded kernel.
    
1714
    string source = cl.replaceStrings(OpenCLKernelSources::coulombLennardJones, defines);
1715
    cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
1716
    if (hasLJ)
1717
        cl.getNonbondedUtilities().addParameter(OpenCLNonbondedUtilities::ParameterInfo("sigmaEpsilon", "float", 2, sizeof(cl_float2), sigmaEpsilon->getDeviceBuffer()));
1718

1719
    // Initialize the exceptions.
1720

1721
1722
1723
1724
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*exceptions.size()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*exceptions.size()/numContexts;
    int numExceptions = endIndex-startIndex;
1725
    if (numExceptions > 0) {
1726
        exceptionAtoms.resize(numExceptions);
Peter Eastman's avatar
Peter Eastman committed
1727
        vector<vector<int> > atoms(numExceptions, vector<int>(2));
1728
        exceptionParams = OpenCLArray::create<mm_float4>(cl, numExceptions, "exceptionParams");
1729
        vector<mm_float4> exceptionParamsVector(numExceptions);
1730
        for (int i = 0; i < numExceptions; i++) {
1731
            double chargeProd, sigma, epsilon;
Peter Eastman's avatar
Peter Eastman committed
1732
            force.getExceptionParameters(exceptions[startIndex+i], atoms[i][0], atoms[i][1], chargeProd, sigma, epsilon);
1733
            exceptionParamsVector[i] = mm_float4((float) (ONE_4PI_EPS0*chargeProd), (float) sigma, (float) (4.0*epsilon), 0.0f);
1734
            exceptionAtoms[i] = make_pair(atoms[i][0], atoms[i][1]);
1735
        }
1736
        exceptionParams->upload(exceptionParamsVector);
Peter Eastman's avatar
Peter Eastman committed
1737
1738
        map<string, string> replacements;
        replacements["PARAMS"] = cl.getBondedUtilities().addArgument(exceptionParams->getDeviceBuffer(), "float4");
1739
        cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::nonbondedExceptions, replacements), force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
1740
1741
    }
    cl.addForce(new OpenCLNonbondedForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force));
1742
1743
}

1744
double OpenCLCalcNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal) {
1745
    bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;
        if (cosSinSums != NULL) {
            ewaldSumsKernel.setArg<cl::Buffer>(0, cl.getEnergyBuffer().getDeviceBuffer());
            ewaldSumsKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
            ewaldSumsKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
            ewaldForcesKernel.setArg<cl::Buffer>(0, cl.getForceBuffers().getDeviceBuffer());
            ewaldForcesKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
            ewaldForcesKernel.setArg<cl::Buffer>(2, cosSinSums->getDeviceBuffer());
        }
1756
        if (pmeGrid != NULL) {
1757
            cl::Program program = cl.createProgram(OpenCLKernelSources::pme, pmeDefines);
1758
            pmeUpdateBsplinesKernel = cl::Kernel(program, "updateBsplines");
1759
            pmeAtomRangeKernel = cl::Kernel(program, "findAtomRangeForGrid");
1760
            pmeZIndexKernel = cl::Kernel(program, "recordZIndex");
1761
1762
1763
            pmeSpreadChargeKernel = cl::Kernel(program, "gridSpreadCharge");
            pmeConvolutionKernel = cl::Kernel(program, "reciprocalConvolution");
            pmeInterpolateForceKernel = cl::Kernel(program, "gridInterpolateForce");
1764
            int elementSize = (cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
1765
1766
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(1, pmeBsplineTheta->getDeviceBuffer());
1767
            pmeUpdateBsplinesKernel.setArg(2, OpenCLContext::ThreadBlockSize*PmeOrder*elementSize, NULL);
1768
            pmeUpdateBsplinesKernel.setArg<cl::Buffer>(3, pmeAtomGridIndex->getDeviceBuffer());
1769
1770
1771
            pmeAtomRangeKernel.setArg<cl::Buffer>(0, pmeAtomGridIndex->getDeviceBuffer());
            pmeAtomRangeKernel.setArg<cl::Buffer>(1, pmeAtomRange->getDeviceBuffer());
            pmeAtomRangeKernel.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
1772
1773
            pmeZIndexKernel.setArg<cl::Buffer>(0, pmeAtomGridIndex->getDeviceBuffer());
            pmeZIndexKernel.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
1774
1775
1776
1777
1778
            pmeSpreadChargeKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
            pmeSpreadChargeKernel.setArg<cl::Buffer>(1, pmeAtomGridIndex->getDeviceBuffer());
            pmeSpreadChargeKernel.setArg<cl::Buffer>(2, pmeAtomRange->getDeviceBuffer());
            pmeSpreadChargeKernel.setArg<cl::Buffer>(3, pmeGrid->getDeviceBuffer());
            pmeSpreadChargeKernel.setArg<cl::Buffer>(4, pmeBsplineTheta->getDeviceBuffer());
Peter Eastman's avatar
Peter Eastman committed
1779
            pmeConvolutionKernel.setArg<cl::Buffer>(0, pmeGrid2->getDeviceBuffer());
1780
1781
1782
1783
1784
1785
            pmeConvolutionKernel.setArg<cl::Buffer>(1, cl.getEnergyBuffer().getDeviceBuffer());
            pmeConvolutionKernel.setArg<cl::Buffer>(2, pmeBsplineModuliX->getDeviceBuffer());
            pmeConvolutionKernel.setArg<cl::Buffer>(3, pmeBsplineModuliY->getDeviceBuffer());
            pmeConvolutionKernel.setArg<cl::Buffer>(4, pmeBsplineModuliZ->getDeviceBuffer());
            pmeInterpolateForceKernel.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
            pmeInterpolateForceKernel.setArg<cl::Buffer>(1, cl.getForceBuffers().getDeviceBuffer());
1786
            pmeInterpolateForceKernel.setArg<cl::Buffer>(2, pmeGrid->getDeviceBuffer());
1787
            pmeInterpolateForceKernel.setArg<cl::Buffer>(7, pmeAtomGridIndex->getDeviceBuffer());
1788
1789
1790
1791
            if (cl.getSupports64BitGlobalAtomics()) {
                pmeFinishSpreadChargeKernel = cl::Kernel(program, "finishSpreadCharge");
                pmeFinishSpreadChargeKernel.setArg<cl::Buffer>(0, pmeGrid->getDeviceBuffer());
            }
1792
       }
1793
    }
1794
    if (cosSinSums != NULL && includeReciprocal) {
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
        mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
        mm_double4 recipBoxSize = mm_double4(2*M_PI/boxSize.x, 2*M_PI/boxSize.y, 2*M_PI/boxSize.z, 0.0);
        double recipCoefficient = ONE_4PI_EPS0*4*M_PI/(boxSize.x*boxSize.y*boxSize.z);
        if (cl.getUseDoublePrecision()) {
            ewaldSumsKernel.setArg<mm_double4>(3, recipBoxSize);
            ewaldSumsKernel.setArg<cl_double>(4, recipCoefficient);
            ewaldForcesKernel.setArg<mm_double4>(3, recipBoxSize);
            ewaldForcesKernel.setArg<cl_double>(4, recipCoefficient);
        }
        else {
            ewaldSumsKernel.setArg<mm_float4>(3, mm_float4((float) recipBoxSize.x, (float) recipBoxSize.y, (float) recipBoxSize.z, 0));
            ewaldSumsKernel.setArg<cl_float>(4, (cl_float) recipCoefficient);
            ewaldForcesKernel.setArg<mm_float4>(3, mm_float4((float) recipBoxSize.x, (float) recipBoxSize.y, (float) recipBoxSize.z, 0));
            ewaldForcesKernel.setArg<cl_float>(4, (cl_float) recipCoefficient);
        }
1810
1811
1812
        cl.executeKernel(ewaldSumsKernel, cosSinSums->getSize());
        cl.executeKernel(ewaldForcesKernel, cl.getNumAtoms());
    }
1813
    if (pmeGrid != NULL && includeReciprocal) {
1814
1815
        if (usePmeQueue)
            cl.setQueue(pmeQueue);
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
        
        // Invert the periodic box vectors.
        
        Vec3 boxVectors[3];
        cl.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
        double determinant = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2];
        double scale = 1.0/determinant;
        mm_double4 recipBoxVectors[3];
        recipBoxVectors[0] = mm_double4(boxVectors[1][1]*boxVectors[2][2]*scale, 0, 0, 0);
        recipBoxVectors[1] = mm_double4(-boxVectors[1][0]*boxVectors[2][2]*scale, boxVectors[0][0]*boxVectors[2][2]*scale, 0, 0);
        recipBoxVectors[2] = mm_double4((boxVectors[1][0]*boxVectors[2][1]-boxVectors[1][1]*boxVectors[2][0])*scale, -boxVectors[0][0]*boxVectors[2][1]*scale, boxVectors[0][0]*boxVectors[1][1]*scale, 0);
        mm_float4 recipBoxVectorsFloat[3];
        for (int i = 0; i < 3; i++)
            recipBoxVectorsFloat[i] = mm_float4((float) recipBoxVectors[i].x, (float) recipBoxVectors[i].y, (float) recipBoxVectors[i].z, 0);
        
        // Execute the reciprocal space kernels.

1833
        setPeriodicBoxSizeArg(cl, pmeUpdateBsplinesKernel, 4);
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
        if (cl.getUseDoublePrecision()) {
            pmeUpdateBsplinesKernel.setArg<mm_double4>(5, recipBoxVectors[0]);
            pmeUpdateBsplinesKernel.setArg<mm_double4>(6, recipBoxVectors[1]);
            pmeUpdateBsplinesKernel.setArg<mm_double4>(7, recipBoxVectors[2]);
        }
        else {
            pmeUpdateBsplinesKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[0]);
            pmeUpdateBsplinesKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[1]);
            pmeUpdateBsplinesKernel.setArg<mm_float4>(7, recipBoxVectorsFloat[2]);
        }
1844
        cl.executeKernel(pmeUpdateBsplinesKernel, cl.getNumAtoms());
1845
        if (deviceIsCpu) {
1846
            setPeriodicBoxSizeArg(cl, pmeSpreadChargeKernel, 5);
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
            if (cl.getUseDoublePrecision()) {
                pmeSpreadChargeKernel.setArg<mm_double4>(6, recipBoxVectors[0]);
                pmeSpreadChargeKernel.setArg<mm_double4>(7, recipBoxVectors[1]);
                pmeSpreadChargeKernel.setArg<mm_double4>(8, recipBoxVectors[2]);
            }
            else {
                pmeSpreadChargeKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[0]);
                pmeSpreadChargeKernel.setArg<mm_float4>(7, recipBoxVectorsFloat[1]);
                pmeSpreadChargeKernel.setArg<mm_float4>(8, recipBoxVectorsFloat[2]);
            }
1857
1858
1859
1860
            cl.executeKernel(pmeSpreadChargeKernel, 2*cl.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(), 1);
        }
        else {
            sort->sort(*pmeAtomGridIndex);
1861
            if (cl.getSupports64BitGlobalAtomics()) {
1862
                setPeriodicBoxSizeArg(cl, pmeSpreadChargeKernel, 5);
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
                if (cl.getUseDoublePrecision()) {
                    pmeSpreadChargeKernel.setArg<mm_double4>(6, recipBoxVectors[0]);
                    pmeSpreadChargeKernel.setArg<mm_double4>(7, recipBoxVectors[1]);
                    pmeSpreadChargeKernel.setArg<mm_double4>(8, recipBoxVectors[2]);
                }
                else {
                    pmeSpreadChargeKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[0]);
                    pmeSpreadChargeKernel.setArg<mm_float4>(7, recipBoxVectorsFloat[1]);
                    pmeSpreadChargeKernel.setArg<mm_float4>(8, recipBoxVectorsFloat[2]);
                }
1873
                cl.executeKernel(pmeSpreadChargeKernel, cl.getNumAtoms());
1874
1875
                cl.executeKernel(pmeFinishSpreadChargeKernel, pmeGrid->getSize());
            }
1876
            else {
1877
                cl.executeKernel(pmeAtomRangeKernel, cl.getNumAtoms());
1878
                setPeriodicBoxSizeArg(cl, pmeZIndexKernel, 2);
1879
1880
1881
1882
                if (cl.getUseDoublePrecision())
                    pmeZIndexKernel.setArg<mm_double4>(3, recipBoxVectors[2]);
                else
                    pmeZIndexKernel.setArg<mm_float4>(3, recipBoxVectorsFloat[2]);
1883
                cl.executeKernel(pmeZIndexKernel, cl.getNumAtoms());
1884
                cl.executeKernel(pmeSpreadChargeKernel, cl.getNumAtoms());
1885
            }
1886
        }
Peter Eastman's avatar
Peter Eastman committed
1887
        fft->execFFT(*pmeGrid, *pmeGrid2, true);
1888
1889
        mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
        double scaleFactor = 1.0/(M_PI*boxSize.x*boxSize.y*boxSize.z);
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
        if (cl.getUseDoublePrecision()) {
            pmeConvolutionKernel.setArg<mm_double4>(5, recipBoxVectors[0]);
            pmeConvolutionKernel.setArg<mm_double4>(6, recipBoxVectors[1]);
            pmeConvolutionKernel.setArg<mm_double4>(7, recipBoxVectors[2]);
            pmeConvolutionKernel.setArg<cl_double>(8, scaleFactor);
        }
        else {
            pmeConvolutionKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[0]);
            pmeConvolutionKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[1]);
            pmeConvolutionKernel.setArg<mm_float4>(7, recipBoxVectorsFloat[2]);
            pmeConvolutionKernel.setArg<cl_float>(8, (float) scaleFactor);
        }
1902
        cl.executeKernel(pmeConvolutionKernel, cl.getNumAtoms());
Peter Eastman's avatar
Peter Eastman committed
1903
        fft->execFFT(*pmeGrid2, *pmeGrid, false);
1904
        setPeriodicBoxSizeArg(cl, pmeInterpolateForceKernel, 3);
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
        if (cl.getUseDoublePrecision()) {
            pmeInterpolateForceKernel.setArg<mm_double4>(4, recipBoxVectors[0]);
            pmeInterpolateForceKernel.setArg<mm_double4>(5, recipBoxVectors[1]);
            pmeInterpolateForceKernel.setArg<mm_double4>(6, recipBoxVectors[2]);
        }
        else {
            pmeInterpolateForceKernel.setArg<mm_float4>(4, recipBoxVectorsFloat[0]);
            pmeInterpolateForceKernel.setArg<mm_float4>(5, recipBoxVectorsFloat[1]);
            pmeInterpolateForceKernel.setArg<mm_float4>(6, recipBoxVectorsFloat[2]);
        }
1915
1916
1917
1918
        if (deviceIsCpu)
            cl.executeKernel(pmeInterpolateForceKernel, 2*cl.getDevice().getInfo<CL_DEVICE_MAX_COMPUTE_UNITS>(), 1);
        else
            cl.executeKernel(pmeInterpolateForceKernel, cl.getNumAtoms());
1919
1920
1921
1922
        if (usePmeQueue) {
            pmeQueue.enqueueMarker(&pmeSyncEvent);
            cl.restoreDefaultQueue();
        }
1923
    }
1924
1925
    double energy = (includeReciprocal ? ewaldSelfEnergy : 0.0);
    if (dispersionCoefficient != 0.0 && includeDirect) {
1926
        mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
1927
1928
1929
        energy += dispersionCoefficient/(boxSize.x*boxSize.y*boxSize.z);
    }
    return energy;
1930
1931
}

1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
/**
 * Copy the per-particle and per-exception parameters of a NonbondedForce into the
 * device arrays, after checking that the new values are compatible with the kernels
 * that were already built.
 *
 * @param context  used to look up the System when recomputing the dispersion correction
 * @param force    the force to read the new parameters from
 * @throws OpenMMException if the particle count changed, if a particle now needs a
 *         Coulomb or Lennard-Jones term that was compiled out, or if the set of
 *         non-excluded exceptions no longer matches the recorded one
 */
void OpenCLCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const NonbondedForce& force) {
    // Make sure the new parameters are acceptable.

    const int numParticles = force.getNumParticles();
    if (numParticles != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");
    if (!hasCoulomb || !hasLJ) {
        for (int particle = 0; particle < numParticles; particle++) {
            double q, sig, eps;
            force.getParticleParameters(particle, q, sig, eps);
            if (!hasCoulomb && q != 0.0)
                throw OpenMMException("updateParametersInContext: The nonbonded force kernel does not include Coulomb interactions, because all charges were originally 0");
            if (!hasLJ && eps != 0.0)
                throw OpenMMException("updateParametersInContext: The nonbonded force kernel does not include Lennard-Jones interactions, because all epsilons were originally 0");
        }
    }

    // Walk the force's exceptions in order and match them against the recorded atom
    // pairs.  An exception that does not match the next recorded pair is only
    // tolerated when both its charge product and epsilon are zero.
    // NOTE(review): initialize() appears to store only this context's slice of the
    // exceptions in exceptionAtoms, while this walk starts from the first exception
    // of the force; confirm the behavior when more than one context is in use.

    vector<int> exceptions;
    const int totalExceptions = force.getNumExceptions();
    for (int e = 0; e < totalExceptions; e++) {
        int atomA, atomB;
        double chargeProd, sigma, epsilon;
        force.getExceptionParameters(e, atomA, atomB, chargeProd, sigma, epsilon);
        const bool matchesNext = (exceptionAtoms.size() > exceptions.size() && make_pair(atomA, atomB) == exceptionAtoms[exceptions.size()]);
        if (matchesNext) {
            exceptions.push_back(e);
            continue;
        }
        if (chargeProd != 0.0 || epsilon != 0.0)
            throw OpenMMException("updateParametersInContext: The set of non-excluded exceptions has changed");
    }

    // Work out which slice of the exceptions belongs to this context.

    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*exceptions.size()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*exceptions.size()/numContexts;
    int numExceptions = endIndex-startIndex;

    // Record the per-particle parameters.  The charge goes into the w component of
    // posq at position i, while sigma/epsilon are stored at the particle's own index.

    OpenCLArray& posq = cl.getPosq();
    posq.download(cl.getPinnedBuffer());
    mm_float4* posqSingle = (mm_float4*) cl.getPinnedBuffer();
    mm_double4* posqDouble = (mm_double4*) cl.getPinnedBuffer();
    vector<mm_float2> sigmaEpsilonValues(cl.getPaddedNumAtoms(), mm_float2(0,0));
    double chargeSquaredSum = 0.0;
    const vector<cl_int>& order = cl.getAtomIndex();
    for (int i = 0; i < numParticles; i++) {
        const int particle = order[i];
        double q, sig, eps;
        force.getParticleParameters(particle, q, sig, eps);
        if (cl.getUseDoublePrecision())
            posqDouble[i].w = q;
        else
            posqSingle[i].w = (float) q;
        sigmaEpsilonValues[particle] = mm_float2((float) (0.5*sig), (float) (2.0*sqrt(eps)));
        chargeSquaredSum += q*q;
    }
    posq.upload(cl.getPinnedBuffer());
    sigmaEpsilon->upload(sigmaEpsilonValues);

    // Record the exceptions.

    if (numExceptions > 0) {
        vector<mm_float4> exceptionValues(numExceptions);
        for (int i = 0; i < numExceptions; i++) {
            int atomA, atomB;
            double chargeProd, sigma, epsilon;
            force.getExceptionParameters(exceptions[startIndex+i], atomA, atomB, chargeProd, sigma, epsilon);
            exceptionValues[i] = mm_float4((float) (ONE_4PI_EPS0*chargeProd), (float) sigma, (float) (4.0*epsilon), 0.0f);
        }
        exceptionParams->upload(exceptionValues);
    }

    // Compute other values.  Only the context with index 0 contributes the Ewald
    // self energy and the dispersion correction.

    NonbondedForce::NonbondedMethod method = force.getNonbondedMethod();
    const bool isFirstContext = (cl.getContextIndex() == 0);
    if (method == NonbondedForce::Ewald || method == NonbondedForce::PME) {
        if (isFirstContext)
            ewaldSelfEnergy = -ONE_4PI_EPS0*alpha*chargeSquaredSum/sqrt(M_PI);
        else
            ewaldSelfEnergy = 0.0;
    }
    const bool periodicMethod = (method == NonbondedForce::CutoffPeriodic || method == NonbondedForce::Ewald || method == NonbondedForce::PME);
    if (force.getUseDispersionCorrection() && isFirstContext && periodicMethod)
        dispersionCoefficient = NonbondedForceImpl::calcDispersionCorrection(context.getSystem(), force);
    cl.invalidateMolecules();
}

2008
2009
2010
class OpenCLCustomNonbondedForceInfo : public OpenCLForceInfo {
public:
    OpenCLCustomNonbondedForceInfo(int requiredBuffers, const CustomNonbondedForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
        if (force.getNumInteractionGroups() > 0) {
            groupsForParticle.resize(force.getNumParticles());
            for (int i = 0; i < force.getNumInteractionGroups(); i++) {
                set<int> set1, set2;
                force.getInteractionGroupParameters(i, set1, set2);
                for (set<int>::const_iterator iter = set1.begin(); iter != set1.end(); ++iter)
                    groupsForParticle[*iter].insert(2*i);
                for (set<int>::const_iterator iter = set2.begin(); iter != set2.end(); ++iter)
                    groupsForParticle[*iter].insert(2*i+1);
            }
        }
2022
2023
2024
2025
2026
2027
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        vector<double> params1;
        vector<double> params2;
        force.getParticleParameters(particle1, params1);
        force.getParticleParameters(particle2, params2);
2028
        for (int i = 0; i < (int) params1.size(); i++)
2029
2030
            if (params1[i] != params2[i])
                return false;
2031
2032
        if (groupsForParticle.size() > 0 && groupsForParticle[particle1] != groupsForParticle[particle2])
            return false;
2033
2034
2035
        return true;
    }
    int getNumParticleGroups() {
2036
        return force.getNumExclusions();
2037
    }
Peter Eastman's avatar
Peter Eastman committed
2038
    void getParticlesInGroup(int index, vector<int>& particles) {
2039
        int particle1, particle2;
2040
        force.getExclusionParticles(index, particle1, particle2);
2041
2042
2043
2044
2045
2046
2047
2048
2049
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomNonbondedForce& force;
2050
    vector<set<int> > groupsForParticle;
2051
2052
2053
2054
2055
2056
2057
};

/**
 * Release the objects this kernel allocated: the parameter set, the global parameter
 * array, the interaction group data, every tabulated function array and the copy of
 * the force kept for the long range correction.
 */
OpenCLCalcCustomNonbondedForceKernel::~OpenCLCalcCustomNonbondedForceKernel() {
    // delete on a null pointer does nothing, so the pointers need no explicit checks.
    delete params;
    delete globals;
    delete interactionGroupData;
    for (int i = 0, count = (int) tabulatedFunctions.size(); i < count; i++)
        delete tabulatedFunctions[i];
    delete forceCopy;
}

void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, const CustomNonbondedForce& force) {
    int forceIndex;
    for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
        ;
2070
    string prefix = (force.getNumInteractionGroups() == 0 ? "custom"+cl.intToString(forceIndex)+"_" : "");
2071
2072
2073
2074

    // Record parameters and exclusions.

    int numParticles = force.getNumParticles();
2075
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customNonbondedParameters");
2076
    if (force.getNumGlobalParameters() > 0)
2077
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customNonbondedGlobals", CL_MEM_READ_ONLY);
2078
    vector<vector<cl_float> > paramVector(numParticles);
2079
2080
2081
2082
    vector<vector<int> > exclusionList(numParticles);
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
        force.getParticleParameters(i, parameters);
2083
        paramVector[i].resize(parameters.size());
2084
        for (int j = 0; j < (int) parameters.size(); j++)
2085
            paramVector[i][j] = (cl_float) parameters[j];
2086
2087
        exclusionList[i].push_back(i);
    }
2088
2089
2090
2091
2092
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int particle1, particle2;
        force.getExclusionParticles(i, particle1, particle2);
        exclusionList[particle1].push_back(particle2);
        exclusionList[particle2].push_back(particle1);
2093
    }
2094
    params->setParameterValues(paramVector);
2095
2096
2097

    // Record the tabulated functions.

2098
2099
    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
2100
    vector<const TabulatedFunction*> functionList;
2101
    for (int i = 0; i < force.getNumFunctions(); i++) {
2102
2103
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
2104
        string arrayName = prefix+"table"+cl.intToString(i);
2105
        functionDefinitions.push_back(make_pair(name, arrayName));
2106
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
2107
        int width;
2108
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
2109
        tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
2110
        tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
peastman's avatar
peastman committed
2111
        cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
    }

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    if (globals != NULL)
        globals->upload(globalParamValues);
    bool useCutoff = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff && force.getNonbondedMethod() != CustomNonbondedForce::CutoffNonPeriodic);
2126
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), functions).optimize();
2127
    Lepton::ParsedExpression forceExpression = energyExpression.differentiate("r").optimize();
2128
2129
2130
    map<string, Lepton::ParsedExpression> forceExpressions;
    forceExpressions["tempEnergy += "] = energyExpression;
    forceExpressions["tempForce -= "] = forceExpression;
2131
2132
2133

    // Create the kernels.

2134
2135
2136
2137
2138
    vector<pair<ExpressionTreeNode, string> > variables;
    ExpressionTreeNode rnode(new Operation::Variable("r"));
    variables.push_back(make_pair(rnode, "r"));
    variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
    variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
2139
2140
    for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
        const string& name = force.getPerParticleParameterName(i);
2141
2142
        variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1")));
        variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2")));
2143
2144
2145
    }
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        const string& name = force.getGlobalParameterName(i);
2146
        string value = "globals["+cl.intToString(i)+"]";
2147
        variables.push_back(makeVariable(name, prefix+value));
2148
    }
2149
    stringstream compute;
2150
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, prefix+"temp");
2151
2152
    map<string, string> replacements;
    replacements["COMPUTE_FORCE"] = compute.str();
2153
2154
2155
2156
2157
2158
2159
2160
2161
    replacements["USE_SWITCH"] = (useCutoff && force.getUseSwitchingFunction() ? "1" : "0");
    if (force.getUseSwitchingFunction()) {
        // Compute the switching coefficients.
        
        replacements["SWITCH_CUTOFF"] = cl.doubleToString(force.getSwitchingDistance());
        replacements["SWITCH_C3"] = cl.doubleToString(10/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 3.0));
        replacements["SWITCH_C4"] = cl.doubleToString(15/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 4.0));
        replacements["SWITCH_C5"] = cl.doubleToString(6/pow(force.getSwitchingDistance()-force.getCutoffDistance(), 5.0));
    }
2162
    string source = cl.replaceStrings(OpenCLKernelSources::customNonbonded, replacements);
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
    if (force.getNumInteractionGroups() > 0)
        initInteractionGroups(force, source);
    else {
        cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
            cl.getNonbondedUtilities().addParameter(OpenCLNonbondedUtilities::ParameterInfo(prefix+"params"+cl.intToString(i+1), buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
        }
        if (globals != NULL) {
            globals->upload(globalParamValues);
            cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", 1, sizeof(cl_float), globals->getDeviceBuffer()));
        }
2175
    }
2176
    cl.addForce(new OpenCLCustomNonbondedForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force));
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
    
    // Record information for the long range correction.
    
    if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic && force.getUseLongRangeCorrection() && cl.getContextIndex() == 0) {
        forceCopy = new CustomNonbondedForce(force);
        hasInitializedLongRangeCorrection = false;
    }
    else {
        longRangeCoefficient = 0.0;
        hasInitializedLongRangeCorrection = true;
    }
2188
2189
}

2190
2191
2192
2193
2194
void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource) {
    // Process groups to form tiles.
    
    vector<vector<int> > atomLists;
    vector<pair<int, int> > tiles;
peastman's avatar
peastman committed
2195
    map<pair<int, int>, int> duplicateInteractions;
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
    for (int group = 0; group < force.getNumInteractionGroups(); group++) {
        // Get the list of atoms in this group and sort them.
        
        set<int> set1, set2;
        force.getInteractionGroupParameters(group, set1, set2);
        vector<int> atoms1, atoms2;
        atoms1.insert(atoms1.begin(), set1.begin(), set1.end());
        atoms2.insert(atoms2.begin(), set2.begin(), set2.end());
        sort(atoms1.begin(), atoms1.end());
        sort(atoms2.begin(), atoms2.end());
        
        // Find how many tiles we will create for this group.
        
        int tileWidth = min(min(32, (int) atoms1.size()), (int) atoms2.size());
2210
2211
        if (tileWidth == 0)
            continue;
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
        int numBlocks1 = (atoms1.size()+tileWidth-1)/tileWidth;
        int numBlocks2 = (atoms2.size()+tileWidth-1)/tileWidth;
        
        // Add the tiles.
        
        for (int i = 0; i < numBlocks1; i++)
            for (int j = 0; j < numBlocks2; j++)
                tiles.push_back(make_pair(atomLists.size()+i, atomLists.size()+numBlocks1+j));
        
        // Add the atom lists.
        
        for (int i = 0; i < numBlocks1; i++) {
            vector<int> atoms;
            int first = i*tileWidth;
            int last = min((i+1)*tileWidth, (int) atoms1.size());
            for (int j = first; j < last; j++)
                atoms.push_back(atoms1[j]);
            atomLists.push_back(atoms);
        }
        for (int i = 0; i < numBlocks2; i++) {
            vector<int> atoms;
            int first = i*tileWidth;
            int last = min((i+1)*tileWidth, (int) atoms2.size());
            for (int j = first; j < last; j++)
                atoms.push_back(atoms2[j]);
            atomLists.push_back(atoms);
        }
peastman's avatar
peastman committed
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
        
        // If this group contains duplicate interactions, record that we need to skip them once.
        
        for (int i = 0; i < (int) atoms1.size(); i++) {
            int a1 = atoms1[i];
            if (set2.find(a1) == set2.end())
                continue;
            for (int j = 0; j < (int) atoms2.size() && atoms2[j] < a1; j++) {
                int a2 = atoms2[j];
                if (set1.find(a2) != set1.end()) {
                    pair<int, int> key = make_pair(a2, a1);
                    if (duplicateInteractions.find(key) == duplicateInteractions.end())
                        duplicateInteractions[key] = 0;
                    duplicateInteractions[key]++;
                }
            }
        }
2256
2257
2258
2259
2260
2261
2262
2263
    }
    
    // Build a lookup table for quickly identifying excluded interactions.
    
    set<pair<int, int> > exclusions;
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int p1, p2;
        force.getExclusionParticles(i, p1, p2);
peastman's avatar
peastman committed
2264
        exclusions.insert(make_pair(min(p1, p2), max(p1, p2)));
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
    }
    
    // Build the exclusion flags for each tile.  While we're at it, filter out tiles
    // where all interactions are excluded, and sort the tiles by size.

    vector<vector<int> > exclusionFlags(tiles.size());
    vector<pair<int, int> > tileOrder;
    for (int tile = 0; tile < tiles.size(); tile++) {
        if (atomLists[tiles[tile].first].size() < atomLists[tiles[tile].second].size()) {
            // For efficiency, we want the first axis to be the larger one.
            
            int swap = tiles[tile].first;
            tiles[tile].first = tiles[tile].second;
            tiles[tile].second = swap;
        }
        vector<int>& atoms1 = atomLists[tiles[tile].first];
        vector<int>& atoms2 = atomLists[tiles[tile].second];
peastman's avatar
peastman committed
2282
        vector<int> flags(atoms1.size(), (int) (1LL<<atoms2.size())-1);
2283
2284
2285
2286
2287
        int numExcluded = 0;
        for (int i = 0; i < (int) atoms1.size(); i++)
            for (int j = 0; j < (int) atoms2.size(); j++) {
                int a1 = atoms1[i];
                int a2 = atoms2[j];
peastman's avatar
peastman committed
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
                bool isExcluded = false;
                pair<int, int> key = make_pair(min(a1, a2), max(a1, a2));
                if (a1 == a2 || exclusions.find(key) != exclusions.end())
                    isExcluded = true; // This is an excluded interaction.
                else if (duplicateInteractions.find(key) != duplicateInteractions.end() && duplicateInteractions[key] > 0) {
                    // Both atoms are in both sets, so skip duplicate interactions.
                    
                    isExcluded = true;
                    duplicateInteractions[key]--;
                }
                if (isExcluded) {
2299
2300
2301
2302
2303
2304
2305
                    flags[i] &= -1-(1<<j);
                    numExcluded++;
                }
            }
        if (numExcluded == atoms1.size()*atoms2.size())
            continue; // All interactions are excluded.
        tileOrder.push_back(make_pair((int) -atoms2.size(), tile));
peastman's avatar
peastman committed
2306
        exclusionFlags[tile] = flags;
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
    }
    sort(tileOrder.begin(), tileOrder.end());
    
    // Merge tiles to get as close as possible to 32 along the first axis of each one.
    
    vector<int> tileSetStart;
    tileSetStart.push_back(0);
    int tileSetSize = 0;
    for (int i = 0; i < tileOrder.size(); i++) {
        int tile = tileOrder[i].second;
        int size = atomLists[tiles[tile].first].size();
        if (tileSetSize+size > 32) {
            tileSetStart.push_back(i);
            tileSetSize = 0;
        }
        tileSetSize += size;
    }
    tileSetStart.push_back(tileOrder.size());
    
    // Build the data structures.
    
    int numTileSets = tileSetStart.size()-1;
    vector<mm_int4> groupData;
    for (int tileSet = 0; tileSet < numTileSets; tileSet++) {
        int indexInTileSet = 0;
2332
2333
2334
2335
2336
2337
2338
2339
        int minSize = 0;
        if (cl.getSIMDWidth() < 32) {
            // We need to include a barrier inside the inner loop, so ensure that all
            // threads will loop the same number of times.
            
            for (int i = tileSetStart[tileSet]; i < tileSetStart[tileSet+1]; i++)
                minSize = max(minSize, (int) atomLists[tiles[tileOrder[i].second].first].size());
        }
2340
2341
2342
2343
        for (int i = tileSetStart[tileSet]; i < tileSetStart[tileSet+1]; i++) {
            int tile = tileOrder[i].second;
            vector<int>& atoms1 = atomLists[tiles[tile].first];
            vector<int>& atoms2 = atomLists[tiles[tile].second];
2344
            int range = indexInTileSet + ((indexInTileSet+max(minSize, (int) atoms1.size()))<<16);
2345
2346
2347
2348
2349
2350
2351
2352
2353
2354
            int allFlags = (1<<atoms2.size())-1;
            for (int j = 0; j < (int) atoms1.size(); j++) {
                int a1 = atoms1[j];
                int a2 = (j < atoms2.size() ? atoms2[j] : 0);
                int flags = (exclusionFlags[tile].size() > 0 ? exclusionFlags[tile][j] : allFlags);
                groupData.push_back(mm_int4(a1, a2, range, flags<<indexInTileSet));
            }
            indexInTileSet += atoms1.size();
        }
        for (; indexInTileSet < 32; indexInTileSet++)
2355
            groupData.push_back(mm_int4(0, 0, minSize<<16, 0));
2356
2357
2358
2359
2360
2361
2362
2363
2364
2365
2366
2367
2368
2369
2370
2371
2372
2373
2374
2375
2376
2377
2378
2379
2380
2381
2382
2383
2384
2385
2386
2387
2388
2389
2390
2391
2392
2393
2394
2395
2396
2397
2398
2399
2400
2401
2402
2403
    }
    interactionGroupData = OpenCLArray::create<mm_int4>(cl, groupData.size(), "interactionGroupData");
    interactionGroupData->upload(groupData);
    
    // Create the kernel.
    
    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = interactionSource;
    const string suffixes[] = {"x", "y", "z", "w"};
    stringstream localData;
    int localDataSize = 0;
    vector<OpenCLNonbondedUtilities::ParameterInfo>& buffers = params->getBuffers(); 
    for (int i = 0; i < (int) buffers.size(); i++) {
        if (buffers[i].getNumComponents() == 1)
            localData<<buffers[i].getComponentType()<<" params"<<(i+1)<<";\n";
        else {
            for (int j = 0; j < buffers[i].getNumComponents(); ++j)
                localData<<buffers[i].getComponentType()<<" params"<<(i+1)<<"_"<<suffixes[j]<<";\n";
        }
        localDataSize += buffers[i].getSize();
    }
    replacements["ATOM_PARAMETER_DATA"] = localData.str();
    stringstream args;
    for (int i = 0; i < (int) buffers.size(); i++)
        args<<", __global const "<<buffers[i].getType()<<"* restrict global_params"<<(i+1);
    if (globals != NULL)
        args<<", __global const float* restrict globals";
    replacements["PARAMETER_ARGUMENTS"] = args.str();
    stringstream load1;
    for (int i = 0; i < (int) buffers.size(); i++)
        load1<<buffers[i].getType()<<" params"<<(i+1)<<"1 = global_params"<<(i+1)<<"[atom1];\n";
    replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
    stringstream loadLocal2;
    for (int i = 0; i < (int) buffers.size(); i++) {
        if (buffers[i].getNumComponents() == 1)
            loadLocal2<<"localData[get_local_id(0)].params"<<(i+1)<<" = global_params"<<(i+1)<<"[atom2];\n";
        else {
            loadLocal2<<buffers[i].getType()<<" temp_params"<<(i+1)<<" = global_params"<<(i+1)<<"[atom2];\n";
            for (int j = 0; j < buffers[i].getNumComponents(); ++j)
                loadLocal2<<"localData[get_local_id(0)].params"<<(i+1)<<"_"<<suffixes[j]<<" = temp_params"<<(i+1)<<"."<<suffixes[j]<<";\n";
        }
    }
    replacements["LOAD_LOCAL_PARAMETERS"] = loadLocal2.str();
    stringstream load2;
    for (int i = 0; i < (int) buffers.size(); i++) {
        if (buffers[i].getNumComponents() == 1)
            load2<<buffers[i].getType()<<" params"<<(i+1)<<"2 = localData[localIndex].params"<<(i+1)<<";\n";
        else {
2404
            load2<<buffers[i].getType()<<" params"<<(i+1)<<"2 = ("<<buffers[i].getType()<<") (";
2405
2406
2407
2408
2409
2410
2411
2412
2413
2414
2415
2416
2417
2418
            for (int j = 0; j < buffers[i].getNumComponents(); ++j) {
                if (j > 0)
                    load2<<", ";
                load2<<"localData[localIndex].params"<<(i+1)<<"_"<<suffixes[j];
            }
            load2<<");\n";
        }
    }
    replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
    map<string, string> defines;
    if (force.getNonbondedMethod() != CustomNonbondedForce::NoCutoff)
        defines["USE_CUTOFF"] = "1";
    if (force.getNonbondedMethod() == CustomNonbondedForce::CutoffPeriodic)
        defines["USE_PERIODIC"] = "1";
2419
    defines["LOCAL_MEMORY_SIZE"] = cl.intToString(max(32, cl.getNonbondedUtilities().getForceThreadBlockSize()));
2420
2421
2422
2423
2424
2425
2426
2427
2428
2429
2430
2431
2432
2433
2434
2435
    double cutoff = force.getCutoffDistance();
    defines["CUTOFF_SQUARED"] = cl.doubleToString(cutoff*cutoff);
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    defines["TILE_SIZE"] = "32";
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*numTileSets/numContexts;
    int endIndex = (cl.getContextIndex()+1)*numTileSets/numContexts;
    defines["FIRST_TILE"] = cl.intToString(startIndex);
    defines["LAST_TILE"] = cl.intToString(endIndex);
    if ((localDataSize/4)%2 == 0 && !cl.getUseDoublePrecision())
        defines["PARAMETER_SIZE_IS_EVEN"] = "1";
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customNonbondedGroups, replacements), defines);
    interactionGroupKernel = cl::Kernel(program, "computeInteractionGroups");
    numGroupThreadBlocks = cl.getNonbondedUtilities().getNumForceThreadBlocks();
}

2436
/**
 * Update global parameters and the long range correction, launch the interaction group
 * kernel if this force uses interaction groups, and return the long range correction energy.
 *
 * @param context        the context in which to execute this kernel
 * @param includeForces  unused by this method
 * @param includeEnergy  unused by this method
 * @return the long range coefficient divided by the periodic box volume
 */
double OpenCLCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (globals != NULL) {
        // Refresh the cached global parameter values, and upload them only if one has changed.

        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed) {
            globals->upload(globalParamValues);
            if (forceCopy != NULL) {
                longRangeCoefficient = CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, context.getOwner());
                hasInitializedLongRangeCorrection = true;
            }
        }
    }
    if (!hasInitializedLongRangeCorrection) {
        longRangeCoefficient = CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, context.getOwner());
        hasInitializedLongRangeCorrection = true;
    }
    if (interactionGroupData != NULL) {
        // The five periodic box arguments follow the four buffers set at the start of the argument list.

        const int periodicBoxArgIndex = 4;
        if (!hasInitializedKernel) {
            hasInitializedKernel = true;
            int index = 0;
            bool useLong = cl.getSupports64BitGlobalAtomics();
            interactionGroupKernel.setArg<cl::Buffer>(index++, (useLong ? cl.getLongForceBuffer() : cl.getForceBuffers()).getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            interactionGroupKernel.setArg<cl::Buffer>(index++, interactionGroupData->getDeviceBuffer());
            index += 5; // The periodic box size arguments are set when the kernel is executed.
            for (int i = 0; i < (int) params->getBuffers().size(); i++)
                interactionGroupKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory());
            if (globals != NULL)
                interactionGroupKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
        }

        // Set the box arguments on every execution rather than only once, so the kernel
        // does not keep using stale values after the periodic box changes.

        setPeriodicBoxArgs(cl, interactionGroupKernel, periodicBoxArgIndex);
        int forceThreadBlockSize = max(32, cl.getNonbondedUtilities().getForceThreadBlockSize());
        cl.executeKernel(interactionGroupKernel, numGroupThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
    }
    mm_double4 boxSize = cl.getPeriodicBoxSizeDouble();
    return longRangeCoefficient/(boxSize.x*boxSize.y*boxSize.z);
}
Peter Eastman's avatar
Peter Eastman committed
2479

2480
2481
2482
2483
2484
2485
2486
2487
2488
2489
2490
2491
2492
2493
2494
2495
2496
/**
 * Copy the per-particle parameters of a CustomNonbondedForce into the context, and
 * recompute the long range correction if this kernel is tracking one.
 *
 * @param context  the context to copy parameters to
 * @param force    the CustomNonbondedForce to copy the parameters from
 * @throws OpenMMException if the number of particles differs from the number of atoms in the context
 */
void OpenCLCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force) {
    int numParticles = force.getNumParticles();
    if (numParticles != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the per-particle parameters.

    vector<vector<cl_float> > paramVector(numParticles);
    vector<double> parameters;
    for (int i = 0; i < numParticles; i++) {
        force.getParticleParameters(i, parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);

    // If necessary, recompute the long range correction.

    if (forceCopy != NULL) {
        longRangeCoefficient = CustomNonbondedForceImpl::calcLongRangeCorrection(force, context.getOwner());
        hasInitializedLongRangeCorrection = true;
        *forceCopy = force;
    }

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

Peter Eastman's avatar
Peter Eastman committed
2510
2511
2512
2513
2514
2515
2516
2517
2518
2519
2520
2521
2522
2523
/**
 * Force info for a GBSAOBCForce.  Two particles are reported as identical when their
 * charge, radius and scaling factor all compare equal.
 */
class OpenCLGBSAOBCForceInfo : public OpenCLForceInfo {
public:
    OpenCLGBSAOBCForceInfo(int requiredBuffers, const GBSAOBCForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        double chargeA, radiusA, scaleA;
        double chargeB, radiusB, scaleB;
        force.getParticleParameters(particle1, chargeA, radiusA, scaleA);
        force.getParticleParameters(particle2, chargeB, radiusB, scaleB);
        if (chargeA != chargeB)
            return false;
        if (radiusA != radiusB)
            return false;
        return scaleA == scaleB;
    }
private:
    const GBSAOBCForce& force;
};

2524
2525
2526
2527
2528
/**
 * Release the device arrays owned by this kernel.  Arrays that were never created are
 * still NULL, and deleting a null pointer has no effect.
 */
OpenCLCalcGBSAOBCForceKernel::~OpenCLCalcGBSAOBCForceKernel() {
    delete params;
    delete bornSum;
    delete longBornSum;
    delete bornRadii;
    delete bornForce;
    delete longBornForce;
    delete obcChain;
}

void OpenCLCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOBCForce& force) {
2542
2543
    if (cl.getPlatformData().contexts.size() > 1)
        throw OpenMMException("GBSAOBCForce does not support using multiple OpenCL devices");
2544
    OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
2545
    params = OpenCLArray::create<mm_float2>(cl, cl.getPaddedNumAtoms(), "gbsaObcParams");
2546
2547
2548
    int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
    bornRadii = new OpenCLArray(cl, cl.getPaddedNumAtoms(), elementSize, "bornRadii");
    obcChain = new OpenCLArray(cl, cl.getPaddedNumAtoms(), elementSize, "obcChain");
2549
    if (cl.getSupports64BitGlobalAtomics()) {
2550
2551
        longBornSum = OpenCLArray::create<cl_long>(cl, cl.getPaddedNumAtoms(), "longBornSum");
        longBornForce = OpenCLArray::create<cl_long>(cl, cl.getPaddedNumAtoms(), "longBornForce");
2552
        bornForce = new OpenCLArray(cl, cl.getPaddedNumAtoms(), elementSize, "bornForce");
2553
2554
        cl.addAutoclearBuffer(*longBornSum);
        cl.addAutoclearBuffer(*longBornForce);
2555
2556
    }
    else {
2557
2558
        bornSum = new OpenCLArray(cl, cl.getPaddedNumAtoms()*nb.getNumForceBuffers(), elementSize, "bornSum");
        bornForce = new OpenCLArray(cl, cl.getPaddedNumAtoms()*nb.getNumForceBuffers(), elementSize, "bornForce");
2559
2560
        cl.addAutoclearBuffer(*bornSum);
        cl.addAutoclearBuffer(*bornForce);
2561
    }
2562
2563
    vector<mm_float4> posqf(cl.getPaddedNumAtoms());
    vector<mm_double4> posqd(cl.getPaddedNumAtoms());
2564
    vector<mm_float2> paramsVector(cl.getPaddedNumAtoms(), mm_float2(1,1));
2565
    const double dielectricOffset = 0.009;
2566
    for (int i = 0; i < force.getNumParticles(); i++) {
2567
2568
2569
        double charge, radius, scalingFactor;
        force.getParticleParameters(i, charge, radius, scalingFactor);
        radius -= dielectricOffset;
2570
        paramsVector[i] = mm_float2((float) radius, (float) (scalingFactor*radius));
2571
2572
2573
2574
        if (cl.getUseDoublePrecision())
            posqd[i] = mm_double4(0, 0, 0, charge);
        else
            posqf[i] = mm_float4(0, 0, 0, (float) charge);
2575
    }
2576
2577
2578
2579
    if (cl.getUseDoublePrecision())
        cl.getPosq().upload(posqd);
    else
        cl.getPosq().upload(posqf);
2580
    params->upload(paramsVector);
2581
    prefactor = -ONE_4PI_EPS0*((1.0/force.getSoluteDielectric())-(1.0/force.getSolventDielectric()));
2582
    surfaceAreaFactor = -6.0*4*M_PI*force.getSurfaceAreaEnergy();
2583
2584
    bool useCutoff = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != GBSAOBCForce::NoCutoff && force.getNonbondedMethod() != GBSAOBCForce::CutoffNonPeriodic);
2585
    string source = OpenCLKernelSources::gbsaObc2;
2586
    nb.addInteraction(useCutoff, usePeriodic, false, force.getCutoffDistance(), vector<vector<int> >(), source, force.getForceGroup());
2587
    nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo("obcParams", "float", 2, sizeof(cl_float2), params->getDeviceBuffer()));;
2588
    nb.addParameter(OpenCLNonbondedUtilities::ParameterInfo("bornForce", "real", 1, elementSize, bornForce->getDeviceBuffer()));;
Peter Eastman's avatar
Peter Eastman committed
2589
    cl.addForce(new OpenCLGBSAOBCForceInfo(nb.getNumForceBuffers(), force));
2590
2591
}

2592
/**
 * Execute the GBSA OBC kernels.  On the first call the kernels are compiled and their
 * fixed arguments are set; on every call the periodic box arguments are refreshed (when
 * a cutoff is in use) and the four kernels are launched in sequence.
 *
 * @param context        the context in which to execute this kernel (not used by this method)
 * @param includeForces  unused by this method
 * @param includeEnergy  unused by this method
 * @return always 0.0
 */
double OpenCLCalcGBSAOBCForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
    bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    if (!hasCreatedKernels) {
        // These Kernels cannot be created in initialize(), because the OpenCLNonbondedUtilities has not been initialized yet then.

        hasCreatedKernels = true;
        maxTiles = (nb.getUseCutoff() ? nb.getInteractingTiles().getSize() : 0);
        map<string, string> defines;
        if (nb.getUseCutoff())
            defines["USE_CUTOFF"] = "1";
        if (nb.getUsePeriodic())
            defines["USE_PERIODIC"] = "1";
        defines["CUTOFF_SQUARED"] = cl.doubleToString(nb.getCutoffDistance()*nb.getCutoffDistance());
        defines["CUTOFF"] = cl.doubleToString(nb.getCutoffDistance());
        defines["PREFACTOR"] = cl.doubleToString(prefactor);
        defines["SURFACE_AREA_FACTOR"] = cl.doubleToString(surfaceAreaFactor);
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
        defines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
        defines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(nb.getForceThreadBlockSize());
        defines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
        int numExclusionTiles = nb.getExclusionTiles().getSize();
        defines["NUM_TILES_WITH_EXCLUSIONS"] = cl.intToString(numExclusionTiles);
        int numContexts = cl.getPlatformData().contexts.size();
        int startExclusionIndex = cl.getContextIndex()*numExclusionTiles/numContexts;
        int endExclusionIndex = (cl.getContextIndex()+1)*numExclusionTiles/numContexts;
        defines["FIRST_EXCLUSION_TILE"] = cl.intToString(startExclusionIndex);
        defines["LAST_EXCLUSION_TILE"] = cl.intToString(endExclusionIndex);
        string platformVendor = cl::Platform(cl.getDevice().getInfo<CL_DEVICE_PLATFORM>()).getInfo<CL_PLATFORM_VENDOR>();
        if (platformVendor == "Apple")
            defines["USE_APPLE_WORKAROUND"] = "1";
        string file;
        if (deviceIsCpu)
            file = OpenCLKernelSources::gbsaObc_cpu;
        else
            file = OpenCLKernelSources::gbsaObc;
        cl::Program program = cl.createProgram(file, defines);
        bool useLong = cl.getSupports64BitGlobalAtomics();

        // Set the arguments of the kernel that accumulates the Born sums.

        int index = 0;
        computeBornSumKernel = cl::Kernel(program, "computeBornSum");
        computeBornSumKernel.setArg<cl::Buffer>(index++, (useLong ? longBornSum->getDeviceBuffer() : bornSum->getDeviceBuffer()));
        computeBornSumKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        computeBornSumKernel.setArg<cl::Buffer>(index++, params->getDeviceBuffer());
        if (nb.getUseCutoff()) {
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
            index += 5; // The periodic box size arguments are set when the kernel is executed.
            computeBornSumKernel.setArg<cl_uint>(index++, maxTiles);
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getBlockCenters().getDeviceBuffer());
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else
            computeBornSumKernel.setArg<cl_uint>(index++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        computeBornSumKernel.setArg<cl::Buffer>(index++, nb.getExclusionTiles().getDeviceBuffer());

        // Set the arguments of the first force kernel.

        force1Kernel = cl::Kernel(program, "computeGBSAForce1");
        index = 0;
        force1Kernel.setArg<cl::Buffer>(index++, (useLong ? cl.getLongForceBuffer().getDeviceBuffer() : cl.getForceBuffers().getDeviceBuffer()));
        force1Kernel.setArg<cl::Buffer>(index++, (useLong ? longBornForce->getDeviceBuffer() : bornForce->getDeviceBuffer()));
        force1Kernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        force1Kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        force1Kernel.setArg<cl::Buffer>(index++, bornRadii->getDeviceBuffer());
        if (nb.getUseCutoff()) {
            force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractingTiles().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractionCount().getDeviceBuffer());
            index += 5; // The periodic box size arguments are set when the kernel is executed.
            force1Kernel.setArg<cl_uint>(index++, maxTiles);
            force1Kernel.setArg<cl::Buffer>(index++, nb.getBlockCenters().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(index++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(index++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else
            force1Kernel.setArg<cl_uint>(index++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        force1Kernel.setArg<cl::Buffer>(index++, nb.getExclusionTiles().getDeviceBuffer());

        // Set the arguments of the two reduction kernels.

        program = cl.createProgram(OpenCLKernelSources::gbsaObcReductions, defines);
        reduceBornSumKernel = cl::Kernel(program, "reduceBornSum");
        reduceBornSumKernel.setArg<cl_int>(0, cl.getPaddedNumAtoms());
        reduceBornSumKernel.setArg<cl_int>(1, nb.getNumForceBuffers());
        reduceBornSumKernel.setArg<cl_float>(2, 1.0f);
        reduceBornSumKernel.setArg<cl_float>(3, 0.8f);
        reduceBornSumKernel.setArg<cl_float>(4, 4.85f);
        reduceBornSumKernel.setArg<cl::Buffer>(5, (useLong ? longBornSum->getDeviceBuffer() : bornSum->getDeviceBuffer()));
        reduceBornSumKernel.setArg<cl::Buffer>(6, params->getDeviceBuffer());
        reduceBornSumKernel.setArg<cl::Buffer>(7, bornRadii->getDeviceBuffer());
        reduceBornSumKernel.setArg<cl::Buffer>(8, obcChain->getDeviceBuffer());
        reduceBornForceKernel = cl::Kernel(program, "reduceBornForce");
        index = 0;
        reduceBornForceKernel.setArg<cl_int>(index++, cl.getPaddedNumAtoms());
        reduceBornForceKernel.setArg<cl_int>(index++, nb.getNumForceBuffers());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, bornForce->getDeviceBuffer());
        if (useLong)
            reduceBornForceKernel.setArg<cl::Buffer>(index++, longBornForce->getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, params->getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, bornRadii->getDeviceBuffer());
        reduceBornForceKernel.setArg<cl::Buffer>(index++, obcChain->getDeviceBuffer());
    }
    if (nb.getUseCutoff()) {
        // The literal indices below match the argument order established above.

        setPeriodicBoxArgs(cl, computeBornSumKernel, 5);
        setPeriodicBoxArgs(cl, force1Kernel, 7);
        if (maxTiles < nb.getInteractingTiles().getSize()) {
            // The interacting tiles array has grown, so pass the new size and buffers to the kernels.

            maxTiles = nb.getInteractingTiles().getSize();
            computeBornSumKernel.setArg<cl::Buffer>(3, nb.getInteractingTiles().getDeviceBuffer());
            computeBornSumKernel.setArg<cl_uint>(10, maxTiles);
            computeBornSumKernel.setArg<cl::Buffer>(13, nb.getInteractingAtoms().getDeviceBuffer());
            force1Kernel.setArg<cl::Buffer>(5, nb.getInteractingTiles().getDeviceBuffer());
            force1Kernel.setArg<cl_uint>(12, maxTiles);
            force1Kernel.setArg<cl::Buffer>(15, nb.getInteractingAtoms().getDeviceBuffer());
        }
    }
    cl.executeKernel(computeBornSumKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(reduceBornSumKernel, cl.getPaddedNumAtoms());
    cl.executeKernel(force1Kernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(reduceBornForceKernel, cl.getPaddedNumAtoms());
    return 0.0;
}
2709

2710
2711
2712
2713
2714
2715
2716
2717
2718
/**
 * Copy the per-particle parameters of a GBSAOBCForce into the context.  The radius and
 * scaled radius are uploaded to the params array, and each charge is written into the
 * w component of the corresponding posq element.
 *
 * @param context  the context to copy parameters to (not used by this method)
 * @param force    the GBSAOBCForce to copy the parameters from
 * @throws OpenMMException if the number of particles differs from the number of atoms in the context
 */
void OpenCLCalcGBSAOBCForceKernel::copyParametersToContext(ContextImpl& context, const GBSAOBCForce& force) {
    // Make sure the new parameters are acceptable.

    int numParticles = force.getNumParticles();
    if (numParticles != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Record the per-particle parameters.  posqf and posqd are two views of the same
    // pinned buffer; which one is used depends on the precision mode.

    OpenCLArray& posq = cl.getPosq();
    mm_float4* posqf = (mm_float4*) cl.getPinnedBuffer();
    mm_double4* posqd = (mm_double4*) cl.getPinnedBuffer();
    posq.download(cl.getPinnedBuffer());
    vector<mm_float2> paramsVector(cl.getPaddedNumAtoms(), mm_float2(1,1));
    const double dielectricOffset = 0.009;
    for (int i = 0; i < numParticles; i++) {
        double charge, radius, scalingFactor;
        force.getParticleParameters(i, charge, radius, scalingFactor);
        radius -= dielectricOffset;
        paramsVector[i] = mm_float2((float) radius, (float) (scalingFactor*radius));
        if (cl.getUseDoublePrecision())
            posqd[i].w = charge;
        else
            posqf[i].w = (float) charge;
    }
    posq.upload(cl.getPinnedBuffer());
    params->upload(paramsVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

2743
2744
2745
2746
2747
2748
2749
2750
2751
class OpenCLCustomGBForceInfo : public OpenCLForceInfo {
public:
    OpenCLCustomGBForceInfo(int requiredBuffers, const CustomGBForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        vector<double> params1;
        vector<double> params2;
        force.getParticleParameters(particle1, params1);
        force.getParticleParameters(particle2, params2);
2752
        for (int i = 0; i < (int) params1.size(); i++)
2753
2754
2755
2756
2757
2758
2759
            if (params1[i] != params2[i])
                return false;
        return true;
    }
    int getNumParticleGroups() {
        return force.getNumExclusions();
    }
Peter Eastman's avatar
Peter Eastman committed
2760
    void getParticlesInGroup(int index, vector<int>& particles) {
2761
2762
2763
2764
2765
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
2776
2777
2778
        int particle1, particle2;
        force.getExclusionParticles(index, particle1, particle2);
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomGBForce& force;
};

/**
 * Release every device array and parameter set owned by this kernel.
 *
 * Applying delete to a null pointer has no effect, so members that were never
 * allocated need no explicit NULL check.
 * NOTE(review): this relies on the constructor initializing all of these pointers
 * to NULL - the constructor is not visible here, confirm in the header.
 */
OpenCLCalcCustomGBForceKernel::~OpenCLCalcCustomGBForceKernel() {
    delete params;
    delete computedValues;
    delete energyDerivs;
    delete energyDerivChain;
    delete longEnergyDerivs;
    delete globals;
    delete valueBuffers;
    delete longValueBuffers;
    // One array is created per tabulated function and stored by raw pointer.
    for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
        delete tabulatedFunctions[i];
}

void OpenCLCalcCustomGBForceKernel::initialize(const System& system, const CustomGBForce& force) {
2796
2797
    if (cl.getPlatformData().contexts.size() > 1)
        throw OpenMMException("CustomGBForce does not support using multiple OpenCL devices");
2798
    bool useExclusionsForValue = false;
2799
    numComputedValues = force.getNumComputedValues();
2800
2801
    vector<string> computedValueNames(force.getNumComputedValues());
    vector<string> computedValueExpressions(force.getNumComputedValues());
2802
2803
    if (force.getNumComputedValues() > 0) {
        CustomGBForce::ComputationType type;
2804
        force.getComputedValueParameters(0, computedValueNames[0], computedValueExpressions[0], type);
2805
2806
2807
2808
        if (type == CustomGBForce::SingleParticle)
            throw OpenMMException("OpenCLPlatform requires that the first computed value for a CustomGBForce be of type ParticlePair or ParticlePairNoExclusions.");
        useExclusionsForValue = (type == CustomGBForce::ParticlePair);
        for (int i = 1; i < force.getNumComputedValues(); i++) {
2809
            force.getComputedValueParameters(i, computedValueNames[i], computedValueExpressions[i], type);
2810
2811
2812
2813
2814
2815
2816
            if (type != CustomGBForce::SingleParticle)
                throw OpenMMException("OpenCLPlatform requires that a CustomGBForce only have one computed value of type ParticlePair or ParticlePairNoExclusions.");
        }
    }
    int forceIndex;
    for (forceIndex = 0; forceIndex < system.getNumForces() && &system.getForce(forceIndex) != &force; ++forceIndex)
        ;
2817
    string prefix = "custom"+cl.intToString(forceIndex)+"_";
2818
2819
2820
2821

    // Record parameters and exclusions.

    int numParticles = force.getNumParticles();
2822
2823
2824
2825
    int paddedNumParticles = cl.getPaddedNumAtoms();
    int numParams = force.getNumPerParticleParameters();
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), paddedNumParticles, "customGBParameters", true);
    computedValues = new OpenCLParameterSet(cl, force.getNumComputedValues(), paddedNumParticles, "customGBComputedValues", true, cl.getUseDoublePrecision());
2826
    if (force.getNumGlobalParameters() > 0)
2827
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customGBGlobals", CL_MEM_READ_ONLY);
2828
    vector<vector<cl_float> > paramVector(paddedNumParticles, vector<cl_float>(numParams, 0));
2829
2830
2831
2832
    vector<vector<int> > exclusionList(numParticles);
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
        force.getParticleParameters(i, parameters);
2833
        for (int j = 0; j < (int) parameters.size(); j++)
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
            paramVector[i][j] = (cl_float) parameters[j];
        exclusionList[i].push_back(i);
    }
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int particle1, particle2;
        force.getExclusionParticles(i, particle1, particle2);
        exclusionList[particle1].push_back(particle2);
        exclusionList[particle2].push_back(particle1);
    }
    params->setParameterValues(paramVector);

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
2849
    vector<const TabulatedFunction*> functionList;
2850
    stringstream tableArgs;
2851
2852
2853
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
2854
        string arrayName = prefix+"table"+cl.intToString(i);
2855
        functionDefinitions.push_back(make_pair(name, arrayName));
2856
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
2857
        int width;
2858
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
2859
        tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
2860
        tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
peastman's avatar
peastman committed
2861
        cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
2862
2863
2864
2865
        tableArgs << ", __global const float";
        if (width > 1)
            tableArgs << width;
        tableArgs << "* restrict " << arrayName;
2866
2867
    }

2868
    // Record the global parameters.
2869
2870
2871
2872
2873
2874
2875
2876
2877

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    if (globals != NULL)
        globals->upload(globalParamValues);
2878
2879
2880

    // Record derivatives of expressions needed for the chain rule terms.

2881
    vector<vector<Lepton::ParsedExpression> > valueGradientExpressions(force.getNumComputedValues());
2882
    vector<vector<Lepton::ParsedExpression> > valueDerivExpressions(force.getNumComputedValues());
Peter Eastman's avatar
Peter Eastman committed
2883
    needParameterGradient = false;
2884
2885
2886
2887
2888
2889
2890
    for (int i = 1; i < force.getNumComputedValues(); i++) {
        Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
        valueGradientExpressions[i].push_back(ex.differentiate("x").optimize());
        valueGradientExpressions[i].push_back(ex.differentiate("y").optimize());
        valueGradientExpressions[i].push_back(ex.differentiate("z").optimize());
        if (!isZeroExpression(valueGradientExpressions[i][0]) || !isZeroExpression(valueGradientExpressions[i][1]) || !isZeroExpression(valueGradientExpressions[i][2]))
            needParameterGradient = true;
2891
2892
         for (int j = 0; j < i; j++)
            valueDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
2893
    }
2894
    vector<vector<Lepton::ParsedExpression> > energyDerivExpressions(force.getNumEnergyTerms());
Peter Eastman's avatar
Peter Eastman committed
2895
    vector<bool> needChainForValue(force.getNumComputedValues(), false);
2896
2897
2898
2899
2900
2901
    for (int i = 0; i < force.getNumEnergyTerms(); i++) {
        string expression;
        CustomGBForce::ComputationType type;
        force.getEnergyTermParameters(i, expression, type);
        Lepton::ParsedExpression ex = Lepton::Parser::parse(expression, functions).optimize();
        for (int j = 0; j < force.getNumComputedValues(); j++) {
Peter Eastman's avatar
Peter Eastman committed
2902
            if (type == CustomGBForce::SingleParticle) {
2903
                energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]).optimize());
Peter Eastman's avatar
Peter Eastman committed
2904
2905
2906
                if (!isZeroExpression(energyDerivExpressions[i].back()))
                    needChainForValue[j] = true;
            }
2907
2908
            else {
                energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"1").optimize());
Peter Eastman's avatar
Peter Eastman committed
2909
2910
                if (!isZeroExpression(energyDerivExpressions[i].back()))
                    needChainForValue[j] = true;
2911
                energyDerivExpressions[i].push_back(ex.differentiate(computedValueNames[j]+"2").optimize());
Peter Eastman's avatar
Peter Eastman committed
2912
2913
                if (!isZeroExpression(energyDerivExpressions[i].back()))
                    needChainForValue[j] = true;
2914
2915
2916
            }
        }
    }
2917
    bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
2918
    bool useLong = cl.getSupports64BitGlobalAtomics();
2919
    if (useLong) {
2920
        longEnergyDerivs = OpenCLArray::create<cl_long>(cl, force.getNumComputedValues()*cl.getPaddedNumAtoms(), "customGBLongEnergyDerivatives");
Peter Eastman's avatar
Peter Eastman committed
2921
        energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivatives", true);
2922
2923
    }
    else
Peter Eastman's avatar
Peter Eastman committed
2924
        energyDerivs = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms()*cl.getNonbondedUtilities().getNumForceBuffers(), "customGBEnergyDerivatives", true);
2925
2926
    energyDerivChain = new OpenCLParameterSet(cl, force.getNumComputedValues(), cl.getPaddedNumAtoms(), "customGBEnergyDerivativeChain", true);

2927
2928
    // Create the kernels.

2929
2930
    bool useCutoff = (force.getNonbondedMethod() != CustomGBForce::NoCutoff);
    bool usePeriodic = (force.getNonbondedMethod() != CustomGBForce::NoCutoff && force.getNonbondedMethod() != CustomGBForce::CutoffNonPeriodic);
2931
2932
2933
    {
        // Create the N2 value kernel.

2934
        vector<pair<ExpressionTreeNode, string> > variables;
2935
        map<string, string> rename;
2936
2937
2938
2939
        ExpressionTreeNode rnode(new Operation::Variable("r"));
        variables.push_back(make_pair(rnode, "r"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
2940
2941
        for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
            const string& name = force.getPerParticleParameterName(i);
2942
2943
            variables.push_back(makeVariable(name+"1", "params"+params->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(name+"2", "params"+params->getParameterSuffix(i, "2")));
2944
2945
            rename[name+"1"] = name+"2";
            rename[name+"2"] = name+"1";
2946
2947
2948
        }
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
2949
            string value = "globals["+cl.intToString(i)+"]";
2950
            variables.push_back(makeVariable(name, value));
2951
        }
2952
2953
        map<string, Lepton::ParsedExpression> n2ValueExpressions;
        stringstream n2ValueSource;
2954
2955
2956
        Lepton::ParsedExpression ex = Lepton::Parser::parse(computedValueExpressions[0], functions).optimize();
        n2ValueExpressions["tempValue1 = "] = ex;
        n2ValueExpressions["tempValue2 = "] = ex.renameVariables(rename);
2957
        n2ValueSource << cl.getExpressionUtilities().createExpressions(n2ValueExpressions, variables, functionList, functionDefinitions, "temp");
2958
        map<string, string> replacements;
Peter Eastman's avatar
Peter Eastman committed
2959
2960
        string n2ValueStr = n2ValueSource.str();
        replacements["COMPUTE_VALUE"] = n2ValueStr;
2961
2962
        stringstream extraArgs, loadLocal1, loadLocal2, load1, load2;
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
2963
            extraArgs << ", __global const float* globals";
Peter Eastman's avatar
Peter Eastman committed
2964
        pairValueUsesParam.resize(params->getBuffers().size(), false);
2965
2966
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
2967
            string paramName = "params"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
2968
2969
2970
2971
2972
2973
2974
2975
            if (n2ValueStr.find(paramName+"1") != n2ValueStr.npos || n2ValueStr.find(paramName+"2") != n2ValueStr.npos) {
                extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName;
                loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
                loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
                load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
                load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
                pairValueUsesParam[i] = true;
            }
2976
        }
2977
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
2978
2979
2980
2981
2982
        replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
        replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
        replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
        replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
        if (useCutoff)
2983
            pairValueDefines["USE_CUTOFF"] = "1";
2984
        if (usePeriodic)
2985
            pairValueDefines["USE_PERIODIC"] = "1";
2986
        if (useExclusionsForValue)
2987
2988
2989
2990
2991
2992
2993
            pairValueDefines["USE_EXCLUSIONS"] = "1";
        pairValueDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize());
        pairValueDefines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
        pairValueDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        pairValueDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
        pairValueDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
        pairValueDefines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
2994
2995
2996
2997
        string file;
        if (deviceIsCpu)
            file = OpenCLKernelSources::customGBValueN2_cpu;
        else
2998
2999
            file = OpenCLKernelSources::customGBValueN2;
        pairValueSrc = cl.replaceStrings(file, replacements);
3000
3001
        if (useExclusionsForValue)
            cl.getNonbondedUtilities().requestExclusions(exclusionList);
3002
3003
3004
3005
3006
3007
    }
    {
        // Create the kernel to reduce the N2 value and calculate other values.

        stringstream reductionSource, extraArgs;
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3008
            extraArgs << ", __global const float* globals";
3009
3010
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3011
            string paramName = "params"+cl.intToString(i+1);
3012
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << paramName;
3013
3014
3015
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3016
            string valueName = "values"+cl.intToString(i+1);
3017
            extraArgs << ", __global " << buffer.getType() << "* restrict global_" << valueName;
3018
3019
3020
            reductionSource << buffer.getType() << " local_" << valueName << ";\n";
        }
        reductionSource << "local_values" << computedValues->getParameterSuffix(0) << " = sum;\n";
3021
        map<string, string> variables;
3022
3023
3024
        variables["x"] = "pos.x";
        variables["y"] = "pos.y";
        variables["z"] = "pos.z";
3025
3026
3027
        for (int i = 0; i < force.getNumPerParticleParameters(); i++)
            variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3028
            variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
3029
3030
3031
3032
        for (int i = 1; i < force.getNumComputedValues(); i++) {
            variables[computedValueNames[i-1]] = "local_values"+computedValues->getParameterSuffix(i-1);
            map<string, Lepton::ParsedExpression> valueExpressions;
            valueExpressions["local_values"+computedValues->getParameterSuffix(i)+" = "] = Lepton::Parser::parse(computedValueExpressions[i], functions).optimize();
3033
            reductionSource << cl.getExpressionUtilities().createExpressions(valueExpressions, variables, functionList, functionDefinitions, "value"+cl.intToString(i)+"_temp");
3034
        }
3035
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
3036
            string valueName = "values"+cl.intToString(i+1);
3037
3038
3039
            reductionSource << "global_" << valueName << "[index] = local_" << valueName << ";\n";
        }
        map<string, string> replacements;
3040
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
3041
3042
        replacements["COMPUTE_VALUES"] = reductionSource.str();
        map<string, string> defines;
3043
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
3044
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBValuePerParticle, replacements), defines);
3045
3046
3047
3048
3049
        perParticleValueKernel = cl::Kernel(program, "computePerParticleValues");
    }
    {
        // Create the N2 energy kernel.

3050
3051
3052
3053
3054
        vector<pair<ExpressionTreeNode, string> > variables;
        ExpressionTreeNode rnode(new Operation::Variable("r"));
        variables.push_back(make_pair(rnode, "r"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
3055
3056
        for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
            const string& name = force.getPerParticleParameterName(i);
3057
3058
            variables.push_back(makeVariable(name+"1", "params"+params->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(name+"2", "params"+params->getParameterSuffix(i, "2")));
3059
3060
        }
        for (int i = 0; i < force.getNumComputedValues(); i++) {
3061
3062
            variables.push_back(makeVariable(computedValueNames[i]+"1", "values"+computedValues->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(computedValueNames[i]+"2", "values"+computedValues->getParameterSuffix(i, "2")));
3063
3064
        }
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3065
            variables.push_back(makeVariable(force.getGlobalParameterName(i), "globals["+cl.intToString(i)+"]"));
3066
        stringstream n2EnergySource;
3067
        bool anyExclusions = (force.getNumExclusions() > 0);
3068
3069
3070
3071
3072
3073
        for (int i = 0; i < force.getNumEnergyTerms(); i++) {
            string expression;
            CustomGBForce::ComputationType type;
            force.getEnergyTermParameters(i, expression, type);
            if (type == CustomGBForce::SingleParticle)
                continue;
3074
            bool exclude = (anyExclusions && type == CustomGBForce::ParticlePair);
3075
            map<string, Lepton::ParsedExpression> n2EnergyExpressions;
3076
3077
            n2EnergyExpressions["tempEnergy += "] = Lepton::Parser::parse(expression, functions).optimize();
            n2EnergyExpressions["dEdR += "] = Lepton::Parser::parse(expression, functions).differentiate("r").optimize();
3078
3079
            if (useLong) {
                for (int j = 0; j < force.getNumComputedValues(); j++) {
Peter Eastman's avatar
Peter Eastman committed
3080
                    if (needChainForValue[j]) {
3081
3082
3083
                        string index = cl.intToString(j+1);
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+index+"_1 += "] = energyDerivExpressions[i][2*j];
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+index+"_2 += "] = energyDerivExpressions[i][2*j+1];
Peter Eastman's avatar
Peter Eastman committed
3084
                    }
3085
3086
3087
3088
                }
            }
            else {
                for (int j = 0; j < force.getNumComputedValues(); j++) {
Peter Eastman's avatar
Peter Eastman committed
3089
                    if (needChainForValue[j]) {
3090
3091
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_1")+" += "] = energyDerivExpressions[i][2*j];
                        n2EnergyExpressions["/*"+cl.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j, "_2")+" += "] = energyDerivExpressions[i][2*j+1];
Peter Eastman's avatar
Peter Eastman committed
3092
                    }
3093
                }
3094
3095
3096
            }
            if (exclude)
                n2EnergySource << "if (!isExcluded) {\n";
3097
            n2EnergySource << cl.getExpressionUtilities().createExpressions(n2EnergyExpressions, variables, functionList, functionDefinitions, "temp");
3098
3099
            if (exclude)
                n2EnergySource << "}\n";
3100
3101
        }
        map<string, string> replacements;
Peter Eastman's avatar
Peter Eastman committed
3102
3103
        string n2EnergyStr = n2EnergySource.str();
        replacements["COMPUTE_INTERACTION"] = n2EnergyStr;
3104
        stringstream extraArgs, loadLocal1, loadLocal2, clearLocal, load1, load2, declare1, recordDeriv, storeDerivs1, storeDerivs2, declareTemps, setTemps;
3105
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3106
            extraArgs << ", __global const float* globals";
Peter Eastman's avatar
Peter Eastman committed
3107
        pairEnergyUsesParam.resize(params->getBuffers().size(), false);
3108
3109
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3110
            string paramName = "params"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3111
3112
3113
3114
3115
3116
3117
3118
            if (n2EnergyStr.find(paramName+"1") != n2EnergyStr.npos || n2EnergyStr.find(paramName+"2") != n2EnergyStr.npos) {
                extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << paramName << ", __local " << buffer.getType() << "* restrict local_" << paramName;
                loadLocal1 << "local_" << paramName << "[localAtomIndex] = " << paramName << "1;\n";
                loadLocal2 << "local_" << paramName << "[localAtomIndex] = global_" << paramName << "[j];\n";
                load1 << buffer.getType() << " " << paramName << "1 = global_" << paramName << "[atom1];\n";
                load2 << buffer.getType() << " " << paramName << "2 = local_" << paramName << "[atom2];\n";
                pairEnergyUsesParam[i] = true;
            }
3119
        }
Peter Eastman's avatar
Peter Eastman committed
3120
        pairEnergyUsesValue.resize(computedValues->getBuffers().size(), false);
3121
3122
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3123
            string valueName = "values"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3124
3125
3126
3127
3128
3129
3130
3131
            if (n2EnergyStr.find(valueName+"1") != n2EnergyStr.npos || n2EnergyStr.find(valueName+"2") != n2EnergyStr.npos) {
                extraArgs << ", __global const " << buffer.getType() << "* restrict global_" << valueName << ", __local " << buffer.getType() << "* restrict local_" << valueName;
                loadLocal1 << "local_" << valueName << "[localAtomIndex] = " << valueName << "1;\n";
                loadLocal2 << "local_" << valueName << "[localAtomIndex] = global_" << valueName << "[j];\n";
                load1 << buffer.getType() << " " << valueName << "1 = global_" << valueName << "[atom1];\n";
                load2 << buffer.getType() << " " << valueName << "2 = local_" << valueName << "[atom2];\n";
                pairEnergyUsesValue[i] = true;
            }
3132
        }
3133
        if (useLong) {
3134
            extraArgs << ", __global long* restrict derivBuffers";
3135
            for (int i = 0; i < force.getNumComputedValues(); i++) {
3136
                string index = cl.intToString(i+1);
3137
                extraArgs << ", __local real* restrict local_deriv" << index;
3138
                clearLocal << "local_deriv" << index << "[localAtomIndex] = 0.0f;\n";
3139
3140
                declare1 << "real deriv" << index << "_1 = 0;\n";
                load2 << "real deriv" << index << "_2 = 0;\n";
3141
3142
3143
                recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n";
                storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")\n";
                storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")\n";
3144
                declareTemps << "__local real tempDerivBuffer" << index << "[64];\n";
3145
3146
3147
3148
3149
3150
                setTemps << "tempDerivBuffer" << index << "[get_local_id(0)] = deriv" << index << "_1;\n";
            }
        }
        else {
            for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
3151
                string index = cl.intToString(i+1);
3152
                extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index << ", __local " << buffer.getType() << "* restrict local_deriv" << index;
3153
3154
3155
3156
3157
3158
3159
3160
3161
                clearLocal << "local_deriv" << index << "[localAtomIndex] = 0.0f;\n";
                declare1 << buffer.getType() << " deriv" << index << "_1 = 0.0f;\n";
                load2 << buffer.getType() << " deriv" << index << "_2 = 0.0f;\n";
                recordDeriv << "local_deriv" << index << "[atom2] += deriv" << index << "_2;\n";
                storeDerivs1 << "STORE_DERIVATIVE_1(" << index << ")\n";
                storeDerivs2 << "STORE_DERIVATIVE_2(" << index << ")\n";
                declareTemps << "__local " << buffer.getType() << " tempDerivBuffer" << index << "[64];\n";
                setTemps << "tempDerivBuffer" << index << "[get_local_id(0)] = deriv" << index << "_1;\n";
            }
3162
        }
3163
3164
3165
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
        replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
        replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
3166
        replacements["CLEAR_LOCAL_DERIVATIVES"] = clearLocal.str();
3167
3168
        replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
        replacements["LOAD_ATOM2_PARAMETERS"] = load2.str();
3169
        replacements["DECLARE_ATOM1_DERIVATIVES"] = declare1.str();
3170
3171
3172
        replacements["RECORD_DERIVATIVE_2"] = recordDeriv.str();
        replacements["STORE_DERIVATIVES_1"] = storeDerivs1.str();
        replacements["STORE_DERIVATIVES_2"] = storeDerivs2.str();
3173
3174
        replacements["DECLARE_TEMP_BUFFERS"] = declareTemps.str();
        replacements["SET_TEMP_BUFFERS"] = setTemps.str();
3175
        if (useCutoff)
3176
            pairEnergyDefines["USE_CUTOFF"] = "1";
3177
        if (usePeriodic)
3178
            pairEnergyDefines["USE_PERIODIC"] = "1";
3179
        if (anyExclusions)
3180
3181
3182
3183
3184
3185
3186
            pairEnergyDefines["USE_EXCLUSIONS"] = "1";
        pairEnergyDefines["FORCE_WORK_GROUP_SIZE"] = cl.intToString(cl.getNonbondedUtilities().getForceThreadBlockSize());
        pairEnergyDefines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
        pairEnergyDefines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        pairEnergyDefines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
        pairEnergyDefines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
        pairEnergyDefines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
3187
3188
3189
3190
        string file;
        if (deviceIsCpu)
            file = OpenCLKernelSources::customGBEnergyN2_cpu;
        else
3191
3192
            file = OpenCLKernelSources::customGBEnergyN2;
        pairEnergySrc = cl.replaceStrings(file, replacements);
3193
3194
3195
3196
    }
    {
        // Create the kernel to reduce the derivatives and calculate per-particle energy terms.

3197
        stringstream compute, extraArgs, reduce;
3198
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3199
            extraArgs << ", __global const float* globals";
3200
3201
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3202
            string paramName = "params"+cl.intToString(i+1);
3203
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << paramName;
3204
3205
3206
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3207
            string valueName = "values"+cl.intToString(i+1);
3208
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << valueName;
3209
        }
3210
3211
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
3212
            string index = cl.intToString(i+1);
3213
            extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index;
3214
3215
            compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
        }
3216
3217
3218
3219
3220
        for (int i = 0; i < (int) energyDerivChain->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivChain->getBuffers()[i];
            string index = cl.intToString(i+1);
            extraArgs << ", __global " << buffer.getType() << "* restrict derivChain" << index;
        }
3221
        if (useLong) {
3222
            extraArgs << ", __global const long* restrict derivBuffersIn";
3223
3224
            for (int i = 0; i < energyDerivs->getNumParameters(); ++i)
                reduce << "derivBuffers" << energyDerivs->getParameterSuffix(i, "[index]") <<
3225
                        " = (1.0f/0x100000000)*derivBuffersIn[index+PADDED_NUM_ATOMS*" << cl.intToString(i) << "];\n";
3226
3227
3228
        }
        else {
            for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++)
3229
                reduce << "REDUCE_VALUE(derivBuffers" << cl.intToString(i+1) << ", " << energyDerivs->getBuffers()[i].getType() << ")\n";
3230
        }
Peter Eastman's avatar
Peter Eastman committed
3231
3232
3233
        
        // Compute the various expressions.
        
3234
        map<string, string> variables;
3235
3236
3237
        variables["x"] = "pos.x";
        variables["y"] = "pos.y";
        variables["z"] = "pos.z";
3238
3239
3240
        for (int i = 0; i < force.getNumPerParticleParameters(); i++)
            variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3241
            variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
3242
3243
        for (int i = 0; i < force.getNumComputedValues(); i++)
            variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
Peter Eastman's avatar
Peter Eastman committed
3244
        map<string, Lepton::ParsedExpression> expressions;
3245
3246
3247
3248
3249
3250
        for (int i = 0; i < force.getNumEnergyTerms(); i++) {
            string expression;
            CustomGBForce::ComputationType type;
            force.getEnergyTermParameters(i, expression, type);
            if (type != CustomGBForce::SingleParticle)
                continue;
3251
            Lepton::ParsedExpression parsed = Lepton::Parser::parse(expression, functions).optimize();
3252
            expressions["/*"+cl.intToString(i+1)+"*/ energy += "] = parsed;
3253
            for (int j = 0; j < force.getNumComputedValues(); j++)
3254
                expressions["/*"+cl.intToString(i+1)+"*/ deriv"+energyDerivs->getParameterSuffix(j)+" += "] = energyDerivExpressions[i][j];
3255
3256
3257
3258
            Lepton::ParsedExpression gradx = parsed.differentiate("x").optimize();
            Lepton::ParsedExpression grady = parsed.differentiate("y").optimize();
            Lepton::ParsedExpression gradz = parsed.differentiate("z").optimize();
            if (!isZeroExpression(gradx))
3259
                expressions["/*"+cl.intToString(i+1)+"*/ force.x -= "] = gradx;
3260
            if (!isZeroExpression(grady))
3261
                expressions["/*"+cl.intToString(i+1)+"*/ force.y -= "] = grady;
3262
            if (!isZeroExpression(gradz))
3263
                expressions["/*"+cl.intToString(i+1)+"*/ force.z -= "] = gradz;
Peter Eastman's avatar
Peter Eastman committed
3264
3265
3266
        }
        for (int i = 1; i < force.getNumComputedValues(); i++)
            for (int j = 0; j < i; j++)
3267
                expressions["real dV"+cl.intToString(i)+"dV"+cl.intToString(j)+" = "] = valueDerivExpressions[i][j];
3268
        compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "temp");
Peter Eastman's avatar
Peter Eastman committed
3269
3270
3271
        
        // Record values.
        
3272
3273
3274
3275
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            string index = cl.intToString(i+1);
            compute << "derivBuffers" << index << "[index] = deriv" << index << ";\n";
        }
Peter Eastman's avatar
Peter Eastman committed
3276
3277
        compute << "forceBuffers[index] = forceBuffers[index]+force;\n";
        for (int i = 1; i < force.getNumComputedValues(); i++) {
3278
            compute << "real totalDeriv"<<i<<" = dV"<<i<<"dV0";
Peter Eastman's avatar
Peter Eastman committed
3279
3280
3281
3282
            for (int j = 1; j < i; j++)
                compute << " + totalDeriv"<<j<<"*dV"<<i<<"dV"<<j;
            compute << ";\n";
            compute << "deriv"<<(i+1)<<" *= totalDeriv"<<i<<";\n";
3283
3284
        }
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
3285
            string index = cl.intToString(i+1);
3286
            compute << "derivChain" << index << "[index] = deriv" << index << ";\n";
3287
3288
3289
        }
        map<string, string> replacements;
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
3290
3291
        replacements["REDUCE_DERIVATIVES"] = reduce.str();
        replacements["COMPUTE_ENERGY"] = compute.str();
3292
        map<string, string> defines;
3293
3294
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
3295
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBEnergyPerParticle, replacements), defines);
3296
        perParticleEnergyKernel = cl::Kernel(program, "computePerParticleEnergy");
3297
    }
Peter Eastman's avatar
Peter Eastman committed
3298
3299
3300
3301
3302
    if (needParameterGradient) {
        // Create the kernel to compute chain rule terms for computed values that depend explicitly on particle coordinates.

        stringstream compute, extraArgs;
        if (force.getNumGlobalParameters() > 0)
Peter Eastman's avatar
Peter Eastman committed
3303
            extraArgs << ", __global const float* globals";
Peter Eastman's avatar
Peter Eastman committed
3304
3305
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3306
            string paramName = "params"+cl.intToString(i+1);
3307
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << paramName;
Peter Eastman's avatar
Peter Eastman committed
3308
3309
3310
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3311
            string valueName = "values"+cl.intToString(i+1);
3312
            extraArgs << ", __global const " << buffer.getType() << "* restrict " << valueName;
Peter Eastman's avatar
Peter Eastman committed
3313
3314
3315
        }
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
3316
            string index = cl.intToString(i+1);
3317
            extraArgs << ", __global " << buffer.getType() << "* restrict derivBuffers" << index;
Peter Eastman's avatar
Peter Eastman committed
3318
3319
3320
3321
3322
3323
3324
3325
3326
            compute << buffer.getType() << " deriv" << index << " = derivBuffers" << index << "[index];\n";
        }
        map<string, string> variables;
        variables["x"] = "pos.x";
        variables["y"] = "pos.y";
        variables["z"] = "pos.z";
        for (int i = 0; i < force.getNumPerParticleParameters(); i++)
            variables[force.getPerParticleParameterName(i)] = "params"+params->getParameterSuffix(i, "[index]");
        for (int i = 0; i < force.getNumGlobalParameters(); i++)
3327
            variables[force.getGlobalParameterName(i)] = "globals["+cl.intToString(i)+"]";
Peter Eastman's avatar
Peter Eastman committed
3328
3329
3330
        for (int i = 0; i < force.getNumComputedValues(); i++)
            variables[computedValueNames[i]] = "values"+computedValues->getParameterSuffix(i, "[index]");
        for (int i = 1; i < force.getNumComputedValues(); i++) {
3331
            string is = cl.intToString(i);
3332
            compute << "real4 dV"<<is<<"dR = (real4) 0;\n";
3333
3334
3335
            for (int j = 1; j < i; j++) {
                if (!isZeroExpression(valueDerivExpressions[i][j])) {
                    map<string, Lepton::ParsedExpression> derivExpressions;
3336
                    string js = cl.intToString(j);
3337
                    derivExpressions["real dV"+is+"dV"+js+" = "] = valueDerivExpressions[i][j];
3338
                    compute << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, "temp_"+is+"_"+js);
3339
3340
3341
3342
                    compute << "dV"<<is<<"dR += dV"<<is<<"dV"<<js<<"*dV"<<js<<"dR;\n";
                }
            }
            map<string, Lepton::ParsedExpression> gradientExpressions;
Peter Eastman's avatar
Peter Eastman committed
3343
            if (!isZeroExpression(valueGradientExpressions[i][0]))
3344
                gradientExpressions["dV"+is+"dR.x += "] = valueGradientExpressions[i][0];
Peter Eastman's avatar
Peter Eastman committed
3345
            if (!isZeroExpression(valueGradientExpressions[i][1]))
3346
                gradientExpressions["dV"+is+"dR.y += "] = valueGradientExpressions[i][1];
Peter Eastman's avatar
Peter Eastman committed
3347
            if (!isZeroExpression(valueGradientExpressions[i][2]))
3348
                gradientExpressions["dV"+is+"dR.z += "] = valueGradientExpressions[i][2];
3349
            compute << cl.getExpressionUtilities().createExpressions(gradientExpressions, variables, functionList, functionDefinitions, "temp");
3350
3351
        }
        for (int i = 1; i < force.getNumComputedValues(); i++) {
3352
            string is = cl.intToString(i);
3353
            compute << "force -= deriv"<<energyDerivs->getParameterSuffix(i)<<"*dV"<<is<<"dR;\n";
Peter Eastman's avatar
Peter Eastman committed
3354
3355
3356
3357
3358
        }
        map<string, string> replacements;
        replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
        replacements["COMPUTE_FORCES"] = compute.str();
        map<string, string> defines;
3359
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
Peter Eastman's avatar
Peter Eastman committed
3360
3361
3362
        cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customGBGradientChainRule, replacements), defines);
        gradientChainRuleKernel = cl::Kernel(program, "computeGradientChainRuleTerms");
    }
3363
    {
Peter Eastman's avatar
Peter Eastman committed
3364
        // Create the code to calculate chain rules terms as part of the default nonbonded kernel.
3365

3366
        vector<pair<ExpressionTreeNode, string> > globalVariables;
3367
3368
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
3369
            string value = "globals["+cl.intToString(i)+"]";
3370
            globalVariables.push_back(makeVariable(name, prefix+value));
3371
        }
3372
        vector<pair<ExpressionTreeNode, string> > variables = globalVariables;
3373
        map<string, string> rename;
3374
3375
3376
3377
        ExpressionTreeNode rnode(new Operation::Variable("r"));
        variables.push_back(make_pair(rnode, "r"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Square(), rnode), "r2"));
        variables.push_back(make_pair(ExpressionTreeNode(new Operation::Reciprocal(), rnode), "invR"));
3378
3379
        for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
            const string& name = force.getPerParticleParameterName(i);
3380
3381
            variables.push_back(makeVariable(name+"1", prefix+"params"+params->getParameterSuffix(i, "1")));
            variables.push_back(makeVariable(name+"2", prefix+"params"+params->getParameterSuffix(i, "2")));
Peter Eastman's avatar
Peter Eastman committed
3382
3383
            rename[name+"1"] = name+"2";
            rename[name+"2"] = name+"1";
3384
3385
3386
3387
        }
        map<string, Lepton::ParsedExpression> derivExpressions;
        stringstream chainSource;
        Lepton::ParsedExpression dVdR = Lepton::Parser::parse(computedValueExpressions[0], functions).differentiate("r").optimize();
3388
3389
        derivExpressions["real dV0dR1 = "] = dVdR;
        derivExpressions["real dV0dR2 = "] = dVdR.renameVariables(rename);
3390
        chainSource << cl.getExpressionUtilities().createExpressions(derivExpressions, variables, functionList, functionDefinitions, prefix+"temp0_");
Peter Eastman's avatar
Peter Eastman committed
3391
3392
3393
3394
3395
3396
3397
        if (needChainForValue[0]) {
            if (useExclusionsForValue)
                chainSource << "if (!isExcluded) {\n";
            chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "1") << ";\n";
            chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(0, "2") << ";\n";
            if (useExclusionsForValue)
                chainSource << "}\n";
3398
        }
Peter Eastman's avatar
Peter Eastman committed
3399
3400
3401
3402
        for (int i = 1; i < force.getNumComputedValues(); i++) {
            if (needChainForValue[i]) {
                chainSource << "tempForce -= dV0dR1*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "1") << ";\n";
                chainSource << "tempForce -= dV0dR2*" << prefix << "dEdV" << energyDerivs->getParameterSuffix(i, "2") << ";\n";
3403
            }
3404
3405
        }
        map<string, string> replacements;
Peter Eastman's avatar
Peter Eastman committed
3406
3407
        string chainStr = chainSource.str();
        replacements["COMPUTE_FORCE"] = chainStr;
3408
        string source = cl.replaceStrings(OpenCLKernelSources::customGBChainRule, replacements);
3409
3410
        vector<OpenCLNonbondedUtilities::ParameterInfo> parameters;
        vector<OpenCLNonbondedUtilities::ParameterInfo> arguments;
3411
3412
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3413
            string paramName = prefix+"params"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3414
3415
            if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
                parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
3416
3417
3418
        }
        for (int i = 0; i < (int) computedValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = computedValues->getBuffers()[i];
3419
            string paramName = prefix+"values"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3420
3421
            if (chainStr.find(paramName+"1") != chainStr.npos || chainStr.find(paramName+"2") != chainStr.npos)
                parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
3422
        }
3423
        for (int i = 0; i < (int) energyDerivChain->getBuffers().size(); i++) {
Peter Eastman's avatar
Peter Eastman committed
3424
            if (needChainForValue[i]) { 
3425
                const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivChain->getBuffers()[i];
3426
                string paramName = prefix+"dEdV"+cl.intToString(i+1);
Peter Eastman's avatar
Peter Eastman committed
3427
3428
                parameters.push_back(OpenCLNonbondedUtilities::ParameterInfo(paramName, buffer.getComponentType(), buffer.getNumComponents(), buffer.getSize(), buffer.getMemory()));
            }
3429
3430
3431
        }
        if (globals != NULL) {
            globals->upload(globalParamValues);
3432
3433
            arguments.push_back(OpenCLNonbondedUtilities::ParameterInfo(prefix+"globals", "float", 1, sizeof(cl_float), globals->getDeviceBuffer()));
        }
3434
        cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, force.getNumExclusions() > 0, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
Peter Eastman's avatar
Peter Eastman committed
3435
3436
3437
3438
        for (int i = 0; i < (int) parameters.size(); i++)
            cl.getNonbondedUtilities().addParameter(parameters[i]);
        for (int i = 0; i < (int) arguments.size(); i++)
            cl.getNonbondedUtilities().addArgument(arguments[i]);
3439
3440
    }
    cl.addForce(new OpenCLCustomGBForceInfo(cl.getNonbondedUtilities().getNumForceBuffers(), force));
3441
    if (useLong)
3442
        cl.addAutoclearBuffer(*longEnergyDerivs);
Peter Eastman's avatar
Peter Eastman committed
3443
3444
3445
    else {
        for (int i = 0; i < (int) energyDerivs->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = energyDerivs->getBuffers()[i];
3446
            cl.addAutoclearBuffer(buffer.getMemory(), buffer.getSize()*energyDerivs->getNumObjects());
Peter Eastman's avatar
Peter Eastman committed
3447
3448
        }
    }
3449
3450
}

3451
/**
 * Launch the CustomGBForce kernels for the current state of the context.
 *
 * On the first call the two pairwise kernels ("computeN2Value" and "computeN2Energy")
 * are compiled and the arguments of every kernel are bound.  On every call the global
 * parameter values are refreshed, the periodic box arguments are updated when a cutoff
 * is in use, and the kernels are queued in a fixed order.
 *
 * @param context        the context, used to look up the current global parameter values
 * @param includeForces  not referenced by this implementation
 * @param includeEnergy  not referenced by this implementation
 * @return always 0.0
 */
double OpenCLCalcCustomGBForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const bool deviceIsCpu = (cl.getDevice().getInfo<CL_DEVICE_TYPE>() == CL_DEVICE_TYPE_CPU);
    OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
    const int elementSize = (cl.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // The two pairwise kernels can't be compiled in initialize(), because the nonbonded
        // utilities object has not yet been initialized then.  Both kernels receive the same
        // exclusion tile range and cutoff, so those values are worked out once.

        const int exclusionTileCount = nb.getExclusionTiles().getSize();
        const int contextCount = cl.getPlatformData().contexts.size();
        const int firstExclusionTile = cl.getContextIndex()*exclusionTileCount/contextCount;
        const int lastExclusionTile = (cl.getContextIndex()+1)*exclusionTileCount/contextCount;
        const string exclusionTileCountText = cl.intToString(exclusionTileCount);
        const string firstExclusionTileText = cl.intToString(firstExclusionTile);
        const string lastExclusionTileText = cl.intToString(lastExclusionTile);
        const string cutoffText = cl.doubleToString(nb.getCutoffDistance());
        {
            pairValueDefines["NUM_TILES_WITH_EXCLUSIONS"] = exclusionTileCountText;
            pairValueDefines["FIRST_EXCLUSION_TILE"] = firstExclusionTileText;
            pairValueDefines["LAST_EXCLUSION_TILE"] = lastExclusionTileText;
            pairValueDefines["CUTOFF"] = cutoffText;
            cl::Program valueProgram = cl.createProgram(pairValueSrc, pairValueDefines);
            pairValueKernel = cl::Kernel(valueProgram, "computeN2Value");

            // The source and defines are no longer needed once the kernel exists.

            pairValueSrc = "";
            pairValueDefines.clear();
        }
        {
            pairEnergyDefines["NUM_TILES_WITH_EXCLUSIONS"] = exclusionTileCountText;
            pairEnergyDefines["FIRST_EXCLUSION_TILE"] = firstExclusionTileText;
            pairEnergyDefines["LAST_EXCLUSION_TILE"] = lastExclusionTileText;
            pairEnergyDefines["CUTOFF"] = cutoffText;
            cl::Program energyProgram = cl.createProgram(pairEnergySrc, pairEnergyDefines);
            pairEnergyKernel = cl::Kernel(energyProgram, "computeN2Energy");
            pairEnergySrc = "";
            pairEnergyDefines.clear();
        }

        // Allocate the buffer the pairwise value kernel accumulates into.

        maxTiles = (nb.getUseCutoff() ? nb.getInteractingTiles().getSize() : 0);
        const bool useLong = cl.getSupports64BitGlobalAtomics();
        if (useLong) {
            longValueBuffers = OpenCLArray::create<cl_long>(cl, cl.getPaddedNumAtoms(), "customGBLongValueBuffers");
            cl.addAutoclearBuffer(*longValueBuffers);
            cl.clearBuffer(*longValueBuffers);
        }
        else {
            valueBuffers = new OpenCLArray(cl, cl.getPaddedNumAtoms()*nb.getNumForceBuffers(), elementSize, "customGBValueBuffers");
            cl.addAutoclearBuffer(*valueBuffers);
            cl.clearBuffer(*valueBuffers);
        }

        // Element count used to size every local-memory argument below: one tile on a CPU
        // device, one force thread block otherwise.

        const int localCount = (deviceIsCpu ? OpenCLContext::TileSize : nb.getForceThreadBlockSize());
        const bool hasGlobals = (globals != NULL);
        const int numParamBuffers = (int) params->getBuffers().size();
        const int numValueBuffers = (int) computedValues->getBuffers().size();
        const int numDerivBuffers = (int) energyDerivs->getBuffers().size();
        const int numTables = (int) tabulatedFunctions.size();

        // Arguments of the pairwise value kernel.

        int arg = 0;
        pairValueKernel.setArg<cl::Buffer>(arg++, cl.getPosq().getDeviceBuffer());
        pairValueKernel.setArg(arg++, localCount*4*elementSize, NULL);
        pairValueKernel.setArg<cl::Buffer>(arg++, nb.getExclusions().getDeviceBuffer());
        pairValueKernel.setArg<cl::Buffer>(arg++, nb.getExclusionTiles().getDeviceBuffer());
        pairValueKernel.setArg<cl::Buffer>(arg++, useLong ? longValueBuffers->getDeviceBuffer() : valueBuffers->getDeviceBuffer());
        pairValueKernel.setArg(arg++, localCount*elementSize, NULL);
        if (nb.getUseCutoff()) {
            pairValueKernel.setArg<cl::Buffer>(arg++, nb.getInteractingTiles().getDeviceBuffer());
            pairValueKernel.setArg<cl::Buffer>(arg++, nb.getInteractionCount().getDeviceBuffer());
            arg += 5; // Periodic box size arguments are set when the kernel is executed.
            pairValueKernel.setArg<cl_uint>(arg++, maxTiles);
            pairValueKernel.setArg<cl::Buffer>(arg++, nb.getBlockCenters().getDeviceBuffer());
            pairValueKernel.setArg<cl::Buffer>(arg++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            pairValueKernel.setArg<cl::Buffer>(arg++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else {
            pairValueKernel.setArg<cl_uint>(arg++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        }
        if (hasGlobals)
            pairValueKernel.setArg<cl::Buffer>(arg++, globals->getDeviceBuffer());
        for (int p = 0; p < numParamBuffers; p++) {
            if (!pairValueUsesParam[p])
                continue;
            const OpenCLNonbondedUtilities::ParameterInfo& info = params->getBuffers()[p];
            pairValueKernel.setArg<cl::Memory>(arg++, info.getMemory());
            pairValueKernel.setArg(arg++, localCount*info.getSize(), NULL);
        }
        for (int t = 0; t < numTables; t++)
            pairValueKernel.setArg<cl::Buffer>(arg++, tabulatedFunctions[t]->getDeviceBuffer());

        // Arguments of the per-particle value kernel.

        arg = 0;
        perParticleValueKernel.setArg<cl_int>(arg++, cl.getPaddedNumAtoms());
        perParticleValueKernel.setArg<cl_int>(arg++, nb.getNumForceBuffers());
        perParticleValueKernel.setArg<cl::Buffer>(arg++, cl.getPosq().getDeviceBuffer());
        perParticleValueKernel.setArg<cl::Buffer>(arg++, useLong ? longValueBuffers->getDeviceBuffer() : valueBuffers->getDeviceBuffer());
        if (hasGlobals)
            perParticleValueKernel.setArg<cl::Buffer>(arg++, globals->getDeviceBuffer());
        for (int p = 0; p < numParamBuffers; p++)
            perParticleValueKernel.setArg<cl::Memory>(arg++, params->getBuffers()[p].getMemory());
        for (int v = 0; v < numValueBuffers; v++)
            perParticleValueKernel.setArg<cl::Memory>(arg++, computedValues->getBuffers()[v].getMemory());
        for (int t = 0; t < numTables; t++)
            perParticleValueKernel.setArg<cl::Buffer>(arg++, tabulatedFunctions[t]->getDeviceBuffer());

        // Arguments of the pairwise energy kernel.

        arg = 0;
        pairEnergyKernel.setArg<cl::Buffer>(arg++, useLong ? cl.getLongForceBuffer().getDeviceBuffer() : cl.getForceBuffers().getDeviceBuffer());
        pairEnergyKernel.setArg<cl::Buffer>(arg++, cl.getEnergyBuffer().getDeviceBuffer());
        pairEnergyKernel.setArg(arg++, localCount*4*elementSize, NULL);
        pairEnergyKernel.setArg<cl::Buffer>(arg++, cl.getPosq().getDeviceBuffer());
        pairEnergyKernel.setArg(arg++, localCount*4*elementSize, NULL);
        pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getExclusions().getDeviceBuffer());
        pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getExclusionTiles().getDeviceBuffer());
        if (nb.getUseCutoff()) {
            pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getInteractingTiles().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getInteractionCount().getDeviceBuffer());
            arg += 5; // Periodic box size arguments are set when the kernel is executed.
            pairEnergyKernel.setArg<cl_uint>(arg++, maxTiles);
            pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getBlockCenters().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getBlockBoundingBoxes().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(arg++, nb.getInteractingAtoms().getDeviceBuffer());
        }
        else {
            pairEnergyKernel.setArg<cl_uint>(arg++, cl.getNumAtomBlocks()*(cl.getNumAtomBlocks()+1)/2);
        }
        if (hasGlobals)
            pairEnergyKernel.setArg<cl::Buffer>(arg++, globals->getDeviceBuffer());
        for (int p = 0; p < numParamBuffers; p++) {
            if (!pairEnergyUsesParam[p])
                continue;
            const OpenCLNonbondedUtilities::ParameterInfo& info = params->getBuffers()[p];
            pairEnergyKernel.setArg<cl::Memory>(arg++, info.getMemory());
            pairEnergyKernel.setArg(arg++, localCount*info.getSize(), NULL);
        }
        for (int v = 0; v < numValueBuffers; v++) {
            if (!pairEnergyUsesValue[v])
                continue;
            const OpenCLNonbondedUtilities::ParameterInfo& info = computedValues->getBuffers()[v];
            pairEnergyKernel.setArg<cl::Memory>(arg++, info.getMemory());
            pairEnergyKernel.setArg(arg++, localCount*info.getSize(), NULL);
        }
        if (useLong) {
            pairEnergyKernel.setArg<cl::Memory>(arg++, longEnergyDerivs->getDeviceBuffer());

            // These local buffers are sized by the force thread block size regardless of device type.

            for (int v = 0; v < numComputedValues; ++v)
                pairEnergyKernel.setArg(arg++, nb.getForceThreadBlockSize()*elementSize, NULL);
        }
        else {
            for (int d = 0; d < numDerivBuffers; d++) {
                const OpenCLNonbondedUtilities::ParameterInfo& info = energyDerivs->getBuffers()[d];
                pairEnergyKernel.setArg<cl::Memory>(arg++, info.getMemory());
                pairEnergyKernel.setArg(arg++, localCount*info.getSize(), NULL);
            }
        }
        for (int t = 0; t < numTables; t++)
            pairEnergyKernel.setArg<cl::Buffer>(arg++, tabulatedFunctions[t]->getDeviceBuffer());

        // Arguments of the per-particle energy kernel.

        arg = 0;
        perParticleEnergyKernel.setArg<cl_int>(arg++, cl.getPaddedNumAtoms());
        perParticleEnergyKernel.setArg<cl_int>(arg++, nb.getNumForceBuffers());
        perParticleEnergyKernel.setArg<cl::Buffer>(arg++, cl.getForceBuffers().getDeviceBuffer());
        perParticleEnergyKernel.setArg<cl::Buffer>(arg++, cl.getEnergyBuffer().getDeviceBuffer());
        perParticleEnergyKernel.setArg<cl::Buffer>(arg++, cl.getPosq().getDeviceBuffer());
        if (hasGlobals)
            perParticleEnergyKernel.setArg<cl::Buffer>(arg++, globals->getDeviceBuffer());
        for (int p = 0; p < numParamBuffers; p++)
            perParticleEnergyKernel.setArg<cl::Memory>(arg++, params->getBuffers()[p].getMemory());
        for (int v = 0; v < numValueBuffers; v++)
            perParticleEnergyKernel.setArg<cl::Memory>(arg++, computedValues->getBuffers()[v].getMemory());
        for (int d = 0; d < numDerivBuffers; d++)
            perParticleEnergyKernel.setArg<cl::Memory>(arg++, energyDerivs->getBuffers()[d].getMemory());
        for (int c = 0; c < (int) energyDerivChain->getBuffers().size(); c++)
            perParticleEnergyKernel.setArg<cl::Memory>(arg++, energyDerivChain->getBuffers()[c].getMemory());
        if (useLong)
            perParticleEnergyKernel.setArg<cl::Memory>(arg++, longEnergyDerivs->getDeviceBuffer());
        for (int t = 0; t < numTables; t++)
            perParticleEnergyKernel.setArg<cl::Buffer>(arg++, tabulatedFunctions[t]->getDeviceBuffer());

        // Arguments of the gradient chain rule kernel, which is only used when needParameterGradient is set.

        if (needParameterGradient) {
            arg = 0;
            gradientChainRuleKernel.setArg<cl::Buffer>(arg++, cl.getForceBuffers().getDeviceBuffer());
            gradientChainRuleKernel.setArg<cl::Buffer>(arg++, cl.getPosq().getDeviceBuffer());
            if (hasGlobals)
                gradientChainRuleKernel.setArg<cl::Buffer>(arg++, globals->getDeviceBuffer());
            for (int p = 0; p < numParamBuffers; p++)
                gradientChainRuleKernel.setArg<cl::Memory>(arg++, params->getBuffers()[p].getMemory());
            for (int v = 0; v < numValueBuffers; v++)
                gradientChainRuleKernel.setArg<cl::Memory>(arg++, computedValues->getBuffers()[v].getMemory());
            for (int d = 0; d < numDerivBuffers; d++)
                gradientChainRuleKernel.setArg<cl::Memory>(arg++, energyDerivs->getBuffers()[d].getMemory());
        }
    }

    // Refresh the global parameter values, uploading them only if at least one has changed.

    if (globals != NULL) {
        bool anyChanged = false;
        for (int g = 0; g < (int) globalParamNames.size(); g++) {
            const cl_float current = (cl_float) context.getParameter(globalParamNames[g]);
            anyChanged = anyChanged || (current != globalParamValues[g]);
            globalParamValues[g] = current;
        }
        if (anyChanged)
            globals->upload(globalParamValues);
    }

    if (nb.getUseCutoff()) {
        // 8 and 9 are the first of the five argument slots skipped with "arg += 5" above.

        setPeriodicBoxArgs(cl, pairValueKernel, 8);
        setPeriodicBoxArgs(cl, pairEnergyKernel, 9);

        // If the interacting tiles array is now larger than the recorded maxTiles, rebind the
        // arguments that refer to it and to the interacting atoms array.

        if (maxTiles < nb.getInteractingTiles().getSize()) {
            maxTiles = nb.getInteractingTiles().getSize();
            pairValueKernel.setArg<cl::Buffer>(6, nb.getInteractingTiles().getDeviceBuffer());
            pairValueKernel.setArg<cl_uint>(13, maxTiles);
            pairValueKernel.setArg<cl::Buffer>(16, nb.getInteractingAtoms().getDeviceBuffer());
            pairEnergyKernel.setArg<cl::Buffer>(7, nb.getInteractingTiles().getDeviceBuffer());
            pairEnergyKernel.setArg<cl_uint>(14, maxTiles);
            pairEnergyKernel.setArg<cl::Buffer>(17, nb.getInteractingAtoms().getDeviceBuffer());
        }
    }

    // Queue the kernels: pairwise values, per-particle values, pairwise energy, per-particle
    // energy, and finally the gradient chain rule terms if they are needed.

    cl.executeKernel(pairValueKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(perParticleValueKernel, cl.getPaddedNumAtoms());
    cl.executeKernel(pairEnergyKernel, nb.getNumForceThreadBlocks()*nb.getForceThreadBlockSize(), nb.getForceThreadBlockSize());
    cl.executeKernel(perParticleEnergyKernel, cl.getPaddedNumAtoms());
    if (needParameterGradient)
        cl.executeKernel(gradientChainRuleKernel, cl.getPaddedNumAtoms());
    return 0.0;
}

3662
3663
3664
3665
3666
3667
3668
/**
 * Re-read the per-particle parameters from a CustomGBForce and push them to the
 * parameter set used by this kernel.
 *
 * @param context  the context the force belongs to (not used here)
 * @param force    the force to copy per-particle parameters from
 * @throws OpenMMException if the force's particle count differs from the
 *         OpenCL context's atom count
 */
void OpenCLCalcCustomGBForceKernel::copyParametersToContext(ContextImpl& context, const CustomGBForce& force) {
    const int particleCount = force.getNumParticles();
    if (particleCount != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Build one row per padded atom.  Rows beyond the real particles keep
    // their zero fill.

    vector<vector<cl_float> > rows(cl.getPaddedNumAtoms(), vector<cl_float>(force.getNumPerParticleParameters(), 0));
    vector<double> values;
    for (int atom = 0; atom < particleCount; atom++) {
        force.getParticleParameters(atom, values);
        vector<cl_float>& row = rows[atom];
        for (int k = 0; k < (int) values.size(); k++)
            row[k] = (cl_float) values[k];
    }
    params->setParameterValues(rows);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

3683
3684
class OpenCLCustomExternalForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
3685
    OpenCLCustomExternalForceInfo(const CustomExternalForce& force, int numParticles) : OpenCLForceInfo(0), force(force), indices(numParticles, -1) {
3686
3687
3688
3689
3690
3691
3692
3693
3694
3695
3696
3697
3698
3699
3700
3701
3702
3703
3704
        vector<double> params;
        for (int i = 0; i < force.getNumParticles(); i++) {
            int particle;
            force.getParticleParameters(i, particle, params);
            indices[particle] = i;
        }
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        particle1 = indices[particle1];
        particle2 = indices[particle2];
        if (particle1 == -1 && particle2 == -1)
            return true;
        if (particle1 == -1 || particle2 == -1)
            return false;
        int temp;
        vector<double> params1;
        vector<double> params2;
        force.getParticleParameters(particle1, temp, params1);
        force.getParticleParameters(particle2, temp, params2);
3705
        for (int i = 0; i < (int) params1.size(); i++)
3706
3707
3708
3709
3710
3711
3712
3713
3714
3715
3716
3717
3718
3719
3720
3721
3722
            if (params1[i] != params2[i])
                return false;
        return true;
    }
private:
    const CustomExternalForce& force;
    vector<int> indices;
};

/**
 * Release the parameter set and the global-parameter array owned by this kernel.
 */
OpenCLCalcCustomExternalForceKernel::~OpenCLCalcCustomExternalForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit guards are needed.
    delete params;
    delete globals;
}

void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const CustomExternalForce& force) {
3723
3724
3725
3726
3727
3728
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumParticles()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumParticles()/numContexts;
    numParticles = endIndex-startIndex;
    if (numParticles == 0)
        return;
3729
    vector<vector<int> > atoms(numParticles, vector<int>(1));
3730
3731
    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customExternalParams");
    vector<vector<cl_float> > paramVector(numParticles);
3732
3733
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
3734
        force.getParticleParameters(startIndex+i, atoms[i][0], parameters);
3735
        paramVector[i].resize(parameters.size());
3736
        for (int j = 0; j < (int) parameters.size(); j++)
3737
            paramVector[i][j] = (cl_float) parameters[j];
3738
    }
3739
    params->setParameterValues(paramVector);
3740
3741
3742
3743
3744
3745
3746
3747
3748
3749
3750
3751
3752
3753
3754
3755
    cl.addForce(new OpenCLCustomExternalForceInfo(force, system.getNumParticles()));

    // Record information for the expressions.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize();
    Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
    Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
    Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
    map<string, Lepton::ParsedExpression> expressions;
    expressions["energy += "] = energyExpression;
3756
3757
3758
    expressions["real dEdX = "] = forceExpressionX;
    expressions["real dEdY = "] = forceExpressionY;
    expressions["real dEdZ = "] = forceExpressionZ;
3759
3760
3761
3762

    // Create the kernels.

    map<string, string> variables;
3763
3764
3765
    variables["x"] = "pos1.x";
    variables["y"] = "pos1.y";
    variables["z"] = "pos1.z";
3766
3767
    for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
        const string& name = force.getPerParticleParameterName(i);
3768
        variables[name] = "particleParams"+params->getParameterSuffix(i);
3769
    }
3770
    if (force.getNumGlobalParameters() > 0) {
3771
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customExternalGlobals", CL_MEM_READ_ONLY);
3772
3773
3774
3775
        globals->upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals->getDeviceBuffer(), "float");
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
3776
            string value = argName+"["+cl.intToString(i)+"]";
3777
3778
            variables[name] = value;
        }
3779
3780
    }
    stringstream compute;
3781
3782
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
3783
3784
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" particleParams"<<(i+1)<<" = "<<argName<<"[index];\n";
3785
    }
peastman's avatar
peastman committed
3786
3787
    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
3788
    compute << cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
3789
    map<string, string> replacements;
3790
    replacements["COMPUTE_FORCE"] = compute.str();
3791
    cl.getBondedUtilities().addInteraction(atoms, cl.replaceStrings(OpenCLKernelSources::customExternalForce, replacements), force.getForceGroup());
3792
3793
}

3794
/**
 * Refresh the device copy of the global parameters if any value in the context
 * differs from the cached one.
 *
 * @return always 0.0
 */
double OpenCLCalcCustomExternalForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (globals == NULL)
        return 0.0;
    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int g = 0; g < numGlobals; g++) {
        const cl_float current = (cl_float) context.getParameter(globalParamNames[g]);
        needsUpload = needsUpload || (current != globalParamValues[g]);
        globalParamValues[g] = current;
    }
    if (needsUpload)
        globals->upload(globalParamValues);
    return 0.0;
}
3808

3809
3810
3811
3812
3813
3814
/**
 * Re-read the per-particle parameters of this context's slice of a
 * CustomExternalForce and push them to the parameter set.
 *
 * @param context  the context the force belongs to (not used here)
 * @param force    the force to copy per-particle parameters from
 * @throws OpenMMException if the slice size no longer matches the number of
 *         terms recorded in numParticles
 */
void OpenCLCalcCustomExternalForceKernel::copyParametersToContext(ContextImpl& context, const CustomExternalForce& force) {
    int contextCount = cl.getPlatformData().contexts.size();
    int firstTerm = cl.getContextIndex()*force.getNumParticles()/contextCount;
    int lastTerm = (cl.getContextIndex()+1)*force.getNumParticles()/contextCount;
    if (lastTerm-firstTerm != numParticles)
        throw OpenMMException("updateParametersInContext: The number of particles has changed");
    if (numParticles == 0)
        return;

    // Record the per-particle parameters.

    vector<vector<cl_float> > rows(numParticles);
    vector<double> values;
    for (int term = 0; term < numParticles; term++) {
        int ignoredAtom;
        force.getParticleParameters(firstTerm+term, ignoredAtom, values);
        vector<cl_float>& row = rows[term];
        row.resize(values.size());
        for (int k = 0; k < (int) values.size(); k++)
            row[k] = (cl_float) values[k];
    }
    params->setParameterValues(rows);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

3836
3837
3838
3839
3840
3841
3842
3843
3844
3845
class OpenCLCustomHbondForceInfo : public OpenCLForceInfo {
public:
    OpenCLCustomHbondForceInfo(int requiredBuffers, const CustomHbondForce& force) : OpenCLForceInfo(requiredBuffers), force(force) {
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        return true;
    }
    int getNumParticleGroups() {
        return force.getNumDonors()+force.getNumAcceptors()+force.getNumExclusions();
    }
Peter Eastman's avatar
Peter Eastman committed
3846
    void getParticlesInGroup(int index, vector<int>& particles) {
3847
3848
3849
3850
        int p1, p2, p3;
        vector<double> parameters;
        if (index < force.getNumDonors()) {
            force.getDonorParameters(index, p1, p2, p3, parameters);
3851
3852
3853
3854
3855
3856
            particles.clear();
            particles.push_back(p1);
            if (p2 > -1)
                particles.push_back(p2);
            if (p3 > -1)
                particles.push_back(p3);
3857
3858
3859
3860
3861
            return;
        }
        index -= force.getNumDonors();
        if (index < force.getNumAcceptors()) {
            force.getAcceptorParameters(index, p1, p2, p3, parameters);
3862
3863
3864
3865
3866
3867
            particles.clear();
            particles.push_back(p1);
            if (p2 > -1)
                particles.push_back(p2);
            if (p3 > -1)
                particles.push_back(p3);
3868
3869
3870
3871
3872
            return;
        }
        index -= force.getNumAcceptors();
        int donor, acceptor;
        force.getExclusionParticles(index, donor, acceptor);
3873
        particles.clear();
3874
        force.getDonorParameters(donor, p1, p2, p3, parameters);
3875
3876
3877
3878
3879
        particles.push_back(p1);
        if (p2 > -1)
            particles.push_back(p2);
        if (p3 > -1)
            particles.push_back(p3);
3880
        force.getAcceptorParameters(acceptor, p1, p2, p3, parameters);
3881
3882
3883
3884
3885
        particles.push_back(p1);
        if (p2 > -1)
            particles.push_back(p2);
        if (p3 > -1)
            particles.push_back(p3);
3886
3887
3888
3889
3890
3891
3892
3893
3894
3895
3896
3897
3898
3899
3900
3901
3902
3903
3904
3905
3906
3907
3908
3909
3910
3911
3912
3913
3914
3915
3916
3917
3918
3919
3920
3921
3922
3923
3924
3925
3926
    }
    bool areGroupsIdentical(int group1, int group2) {
        int p1, p2, p3;
        vector<double> params1, params2;
        if (group1 < force.getNumDonors() && group2 < force.getNumDonors()) {
            force.getDonorParameters(group1, p1, p2, p3, params1);
            force.getDonorParameters(group2, p1, p2, p3, params2);
            return (params1 == params2 && params1 == params2);
        }
        if (group1 < force.getNumDonors() || group2 < force.getNumDonors())
            return false;
        group1 -= force.getNumDonors();
        group2 -= force.getNumDonors();
        if (group1 < force.getNumAcceptors() && group2 < force.getNumAcceptors()) {
            force.getAcceptorParameters(group1, p1, p2, p3, params1);
            force.getAcceptorParameters(group2, p1, p2, p3, params2);
            return (params1 == params2 && params1 == params2);
        }
        if (group1 < force.getNumAcceptors() || group2 < force.getNumAcceptors())
            return false;
        return true;
    }
private:
    const CustomHbondForce& force;
};

/**
 * Release every device array and parameter set owned by this kernel,
 * including the arrays created for tabulated functions.
 */
OpenCLCalcCustomHbondForceKernel::~OpenCLCalcCustomHbondForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit guards are needed.
    delete donorParams;
    delete acceptorParams;
    delete donors;
    delete acceptors;
    delete donorBufferIndices;
    delete acceptorBufferIndices;
    delete globals;
    delete donorExclusions;
    delete acceptorExclusions;
    const int numTables = (int) tabulatedFunctions.size();
    for (int t = 0; t < numTables; t++)
        delete tabulatedFunctions[t];
}

3935
3936
3937
3938
3939
3940
3941
3942
3943
3944
3945
3946
/**
 * Append the same text to both the donor and the acceptor code streams
 * (donor stream first).
 */
static void addDonorAndAcceptorCode(stringstream& computeDonor, stringstream& computeAcceptor, const string& value) {
    stringstream* const targets[] = {&computeDonor, &computeAcceptor};
    for (int t = 0; t < 2; t++)
        *targets[t] << value;
}

/**
 * Emit a statement that accumulates `value` into the force of one atom.
 * Atom indices 0-2 are written to the acceptor stream as f1-f3; indices 3-5
 * are written to the donor stream as f1-f3.
 */
static void applyDonorAndAcceptorForces(stringstream& applyToDonor, stringstream& applyToAcceptor, int atom, const string& value) {
    static const char* const forceNames[] = {"f1", "f2", "f3"};
    const bool isAcceptorAtom = (atom < 3);
    stringstream& target = (isAcceptorAtom ? applyToAcceptor : applyToDonor);
    const int slot = (isAcceptorAtom ? atom : atom-3);
    target << forceNames[slot] << ".xyz += " << value << ";\n";
}
3947

3948
void OpenCLCalcCustomHbondForceKernel::initialize(const System& system, const CustomHbondForce& force) {
3949
3950
    // Record the lists of donors and acceptors, and the parameters for each one.

3951
3952
3953
3954
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumDonors()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumDonors()/numContexts;
    numDonors = endIndex-startIndex;
3955
    numAcceptors = force.getNumAcceptors();
3956
3957
    if (numDonors == 0 || numAcceptors == 0)
        return;
3958
    int numParticles = system.getNumParticles();
3959
3960
    donors = OpenCLArray::create<mm_int4>(cl, numDonors, "customHbondDonors");
    acceptors = OpenCLArray::create<mm_int4>(cl, numAcceptors, "customHbondAcceptors");
3961
3962
3963
    donorParams = new OpenCLParameterSet(cl, force.getNumPerDonorParameters(), numDonors, "customHbondDonorParameters");
    acceptorParams = new OpenCLParameterSet(cl, force.getNumPerAcceptorParameters(), numAcceptors, "customHbondAcceptorParameters");
    if (force.getNumGlobalParameters() > 0)
3964
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customHbondGlobals", CL_MEM_READ_ONLY);
3965
3966
3967
3968
    vector<vector<cl_float> > donorParamVector(numDonors);
    vector<mm_int4> donorVector(numDonors);
    for (int i = 0; i < numDonors; i++) {
        vector<double> parameters;
3969
        force.getDonorParameters(startIndex+i, donorVector[i].x, donorVector[i].y, donorVector[i].z, parameters);
3970
3971
3972
3973
3974
3975
3976
3977
3978
3979
3980
3981
3982
3983
3984
3985
3986
3987
        donorParamVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            donorParamVector[i][j] = (cl_float) parameters[j];
    }
    donors->upload(donorVector);
    donorParams->setParameterValues(donorParamVector);
    vector<vector<cl_float> > acceptorParamVector(numAcceptors);
    vector<mm_int4> acceptorVector(numAcceptors);
    for (int i = 0; i < numAcceptors; i++) {
        vector<double> parameters;
        force.getAcceptorParameters(i, acceptorVector[i].x, acceptorVector[i].y, acceptorVector[i].z, parameters);
        acceptorParamVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            acceptorParamVector[i][j] = (cl_float) parameters[j];
    }
    acceptors->upload(acceptorVector);
    acceptorParams->setParameterValues(acceptorParamVector);

3988
    // Select an output buffer index for each donor and acceptor.
3989

3990
3991
    donorBufferIndices = OpenCLArray::create<mm_int4>(cl, numDonors, "customHbondDonorBuffers");
    acceptorBufferIndices = OpenCLArray::create<mm_int4>(cl, numAcceptors, "customHbondAcceptorBuffers");
3992
3993
    vector<mm_int4> donorBufferVector(numDonors);
    vector<mm_int4> acceptorBufferVector(numAcceptors);
3994
    vector<int> donorBufferCounter(numParticles, 0);
3995
    for (int i = 0; i < numDonors; i++)
3996
3997
3998
        donorBufferVector[i] = mm_int4(donorVector[i].x > -1 ? donorBufferCounter[donorVector[i].x]++ : 0,
                                       donorVector[i].y > -1 ? donorBufferCounter[donorVector[i].y]++ : 0,
                                       donorVector[i].z > -1 ? donorBufferCounter[donorVector[i].z]++ : 0, 0);
3999
    vector<int> acceptorBufferCounter(numParticles, 0);
4000
    for (int i = 0; i < numAcceptors; i++)
4001
4002
4003
        acceptorBufferVector[i] = mm_int4(acceptorVector[i].x > -1 ? acceptorBufferCounter[acceptorVector[i].x]++ : 0,
                                       acceptorVector[i].y > -1 ? acceptorBufferCounter[acceptorVector[i].y]++ : 0,
                                       acceptorVector[i].z > -1 ? acceptorBufferCounter[acceptorVector[i].z]++ : 0, 0);
4004
4005
    donorBufferIndices->upload(donorBufferVector);
    acceptorBufferIndices->upload(acceptorBufferVector);
4006
4007
4008
4009
4010
4011
    int maxBuffers = 1;
    for (int i = 0; i < (int) donorBufferCounter.size(); i++)
        maxBuffers = max(maxBuffers, donorBufferCounter[i]);
    for (int i = 0; i < (int) acceptorBufferCounter.size(); i++)
        maxBuffers = max(maxBuffers, acceptorBufferCounter[i]);
    cl.addForce(new OpenCLCustomHbondForceInfo(maxBuffers, force));
4012
4013
4014

    // Record exclusions.

4015
4016
    vector<mm_int4> donorExclusionVector(numDonors, mm_int4(-1, -1, -1, -1));
    vector<mm_int4> acceptorExclusionVector(numAcceptors, mm_int4(-1, -1, -1, -1));
4017
4018
4019
    for (int i = 0; i < force.getNumExclusions(); i++) {
        int donor, acceptor;
        force.getExclusionParticles(i, donor, acceptor);
4020
4021
4022
        if (donor < startIndex || donor >= endIndex)
            continue;
        donor -= startIndex;
4023
4024
4025
4026
4027
4028
4029
4030
4031
4032
4033
4034
4035
4036
4037
4038
4039
4040
4041
4042
        if (donorExclusionVector[donor].x == -1)
            donorExclusionVector[donor].x = acceptor;
        else if (donorExclusionVector[donor].y == -1)
            donorExclusionVector[donor].y = acceptor;
        else if (donorExclusionVector[donor].z == -1)
            donorExclusionVector[donor].z = acceptor;
        else if (donorExclusionVector[donor].w == -1)
            donorExclusionVector[donor].w = acceptor;
        else
            throw OpenMMException("CustomHbondForce: OpenCLPlatform does not support more than four exclusions per donor");
        if (acceptorExclusionVector[acceptor].x == -1)
            acceptorExclusionVector[acceptor].x = donor;
        else if (acceptorExclusionVector[acceptor].y == -1)
            acceptorExclusionVector[acceptor].y = donor;
        else if (acceptorExclusionVector[acceptor].z == -1)
            acceptorExclusionVector[acceptor].z = donor;
        else if (acceptorExclusionVector[acceptor].w == -1)
            acceptorExclusionVector[acceptor].w = donor;
        else
            throw OpenMMException("CustomHbondForce: OpenCLPlatform does not support more than four exclusions per acceptor");
4043
    }
4044
4045
    donorExclusions = OpenCLArray::create<mm_int4>(cl, numDonors, "customHbondDonorExclusions");
    acceptorExclusions = OpenCLArray::create<mm_int4>(cl, numAcceptors, "customHbondAcceptorExclusions");
4046
4047
    donorExclusions->upload(donorExclusionVector);
    acceptorExclusions->upload(acceptorExclusionVector);
4048
4049
4050
4051
4052

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
4053
    vector<const TabulatedFunction*> functionList;
4054
4055
    stringstream tableArgs;
    for (int i = 0; i < force.getNumFunctions(); i++) {
4056
4057
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
4058
        string arrayName = "table"+cl.intToString(i);
4059
        functionDefinitions.push_back(make_pair(name, arrayName));
4060
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
4061
        int width;
4062
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
4063
        tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
4064
        tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
peastman's avatar
peastman committed
4065
4066
4067
4068
        tableArgs << ", __global const float";
        if (width > 1)
            tableArgs << width;
        tableArgs << "* restrict " << arrayName;
4069
4070
    }

4071
    // Record information about parameters.
4072
4073
4074
4075
4076
4077
4078
4079
4080
4081
4082
4083
4084
4085
4086
4087
4088
4089
4090
4091

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    if (globals != NULL)
        globals->upload(globalParamValues);
    map<string, string> variables;
    for (int i = 0; i < force.getNumPerDonorParameters(); i++) {
        const string& name = force.getPerDonorParameterName(i);
        variables[name] = "donorParams"+donorParams->getParameterSuffix(i);
    }
    for (int i = 0; i < force.getNumPerAcceptorParameters(); i++) {
        const string& name = force.getPerAcceptorParameterName(i);
        variables[name] = "acceptorParams"+acceptorParams->getParameterSuffix(i);
    }
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        const string& name = force.getGlobalParameterName(i);
4092
        variables[name] = "globals["+cl.intToString(i)+"]";
4093
    }
4094
4095
4096
4097
4098
4099
4100
4101
4102
4103
4104
4105
4106
4107
4108
4109
4110
4111
4112

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    computedDeltas.insert("D1A1");
    string atomNames[] = {"A1", "A2", "A3", "D1", "D2", "D3"};
    string atomNamesLower[] = {"a1", "a2", "a3", "d1", "d2", "d3"};
    stringstream computeDonor, computeAcceptor, extraArgs;
    int index = 0;
    for (map<string, vector<int> >::const_iterator iter = distances.begin(); iter != distances.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        if (computedDeltas.count(deltaName) == 0) {
4113
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+");\n");
4114
4115
            computedDeltas.insert(deltaName);
        }
4116
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real r_"+deltaName+" = SQRT(delta"+deltaName+".w);\n");
4117
        variables[iter->first] = "r_"+deltaName;
4118
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
4119
4120
4121
4122
4123
4124
4125
4126
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = angles.begin(); iter != angles.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
        if (computedDeltas.count(deltaName1) == 0) {
4127
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[0]]+");\n");
4128
4129
4130
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
4131
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[1]]+", "+atomNamesLower[atoms[2]]+");\n");
4132
4133
            computedDeltas.insert(deltaName2);
        }
4134
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real "+angleName+" = computeAngle(delta"+deltaName1+", delta"+deltaName2+");\n");
4135
        variables[iter->first] = angleName;
4136
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
4137
4138
4139
4140
4141
4142
4143
4144
4145
4146
4147
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = dihedrals.begin(); iter != dihedrals.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
        if (computedDeltas.count(deltaName1) == 0) {
4148
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName1+" = delta("+atomNamesLower[atoms[0]]+", "+atomNamesLower[atoms[1]]+");\n");
4149
4150
4151
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
4152
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName2+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[1]]+");\n");
4153
4154
4155
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
4156
            addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 delta"+deltaName3+" = delta("+atomNamesLower[atoms[2]]+", "+atomNamesLower[atoms[3]]+");\n");
4157
4158
            computedDeltas.insert(deltaName3);
        }
4159
4160
4161
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 "+crossName1+" = computeCross(delta"+deltaName1+", delta"+deltaName2+");\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 "+crossName2+" = computeCross(delta"+deltaName2+", delta"+deltaName3+");\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real "+dihedralName+" = computeAngle("+crossName1+", "+crossName2+");\n");
4162
4163
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, dihedralName+" *= (delta"+deltaName1+".x*"+crossName2+".x + delta"+deltaName1+".y*"+crossName2+".y + delta"+deltaName1+".z*"+crossName2+".z < 0 ? -1 : 1);\n");
        variables[iter->first] = dihedralName;
4164
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
4165
4166
4167
4168
    }

    // Next it needs to load parameters from global memory.

4169
    if (force.getNumGlobalParameters() > 0)
4170
        extraArgs << ", __global const float* restrict globals";
4171
4172
    for (int i = 0; i < (int) donorParams->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = donorParams->getBuffers()[i];
4173
        extraArgs << ", __global const "+buffer.getType()+"* restrict donor"+buffer.getName();
4174
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, buffer.getType()+" donorParams"+cl.intToString(i+1)+" = donor"+buffer.getName()+"[index];\n");
4175
4176
4177
    }
    for (int i = 0; i < (int) acceptorParams->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
4178
        extraArgs << ", __global const "+buffer.getType()+"* restrict acceptor"+buffer.getName();
4179
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, buffer.getType()+" acceptorParams"+cl.intToString(i+1)+" = acceptor"+buffer.getName()+"[index];\n");
4180
    }
4181
4182
4183

    // Now evaluate the expressions.

4184
    computeAcceptor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
4185
    forceExpressions["energy += "] = energyExpression;
4186
    computeDonor << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
4187
4188
4189
4190
4191
4192
4193

    // Finally, apply forces to atoms.

    index = 0;
    for (map<string, vector<int> >::const_iterator iter = distances.begin(); iter != distances.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
4194
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
4195
4196
4197
4198
4199
4200
4201
4202
4203
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "-"+value);
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], value);
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = angles.begin(); iter != angles.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "{\n");
4204
4205
4206
4207
4208
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 crossProd = cross(delta"+deltaName2+", delta"+deltaName1+");\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real lengthCross = max(length(crossProd), (real) 1e-6f);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 deltaCross0 = -cross(delta"+deltaName1+", crossProd)*dEdAngle"+cl.intToString(index)+"/(delta"+deltaName1+".w*lengthCross);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 deltaCross2 = cross(delta"+deltaName2+", crossProd)*dEdAngle"+cl.intToString(index)+"/(delta"+deltaName2+".w*lengthCross);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 deltaCross1 = -(deltaCross0+deltaCross2);\n");
4209
4210
4211
4212
4213
4214
4215
4216
4217
4218
4219
4220
4221
4222
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "deltaCross0.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], "deltaCross1.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[2], "deltaCross2.xyz");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "}\n");
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = dihedrals.begin(); iter != dihedrals.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "{\n");
4223
4224
4225
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real r = SQRT(delta"+deltaName2+".w);\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 ff;\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.x = (-dEdDihedral"+cl.intToString(index)+"*r)/"+crossName1+".w;\n");
4226
4227
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.y = (delta"+deltaName1+".x*delta"+deltaName2+".x + delta"+deltaName1+".y*delta"+deltaName2+".y + delta"+deltaName1+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.z = (delta"+deltaName3+".x*delta"+deltaName2+".x + delta"+deltaName3+".y*delta"+deltaName2+".y + delta"+deltaName3+".z*delta"+deltaName2+".z)/delta"+deltaName2+".w;\n");
4228
4229
4230
4231
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "ff.w = (dEdDihedral"+cl.intToString(index)+"*r)/"+crossName2+".w;\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 internalF0 = ff.x*"+crossName1+";\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 internalF3 = ff.w*"+crossName2+";\n");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "real4 s = ff.y*internalF0 - ff.z*internalF3;\n");
4232
4233
4234
4235
4236
4237
4238
4239
4240
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[0], "internalF0.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[1], "s.xyz-internalF0.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[2], "-s.xyz-internalF3.xyz");
        applyDonorAndAcceptorForces(computeDonor, computeAcceptor, atoms[3], "internalF3.xyz");
        addDonorAndAcceptorCode(computeDonor, computeAcceptor, "}\n");
    }

    // Generate the kernels.

4241
    map<string, string> replacements;
4242
4243
    replacements["COMPUTE_DONOR_FORCE"] = computeDonor.str();
    replacements["COMPUTE_ACCEPTOR_FORCE"] = computeAcceptor.str();
4244
4245
    replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
    map<string, string> defines;
4246
4247
4248
4249
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    defines["NUM_DONORS"] = cl.intToString(numDonors);
    defines["NUM_ACCEPTORS"] = cl.intToString(numAcceptors);
    defines["PI"] = cl.doubleToString(M_PI);
4250
4251
    if (force.getNonbondedMethod() != CustomHbondForce::NoCutoff) {
        defines["USE_CUTOFF"] = "1";
4252
        defines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
4253
4254
4255
    }
    if (force.getNonbondedMethod() != CustomHbondForce::NoCutoff && force.getNonbondedMethod() != CustomHbondForce::CutoffNonPeriodic)
        defines["USE_PERIODIC"] = "1";
4256
4257
    if (force.getNumExclusions() > 0)
        defines["USE_EXCLUSIONS"] = "1";
4258
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customHbondForce, replacements), defines);
4259
4260
    donorKernel = cl::Kernel(program, "computeDonorForces");
    acceptorKernel = cl::Kernel(program, "computeAcceptorForces");
4261
4262
}

4263
// Launch the donor and acceptor kernels for this force.
//
// Nothing is done if this context has no donors or no acceptors.  Otherwise the
// current values of the global parameters are read from the context and
// re-uploaded only if at least one of them changed.  Kernel arguments are bound
// once, on the first call.  The includeForces and includeEnergy flags are not
// consulted here.  The function always returns 0.0.
double OpenCLCalcCustomHbondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (numDonors == 0 || numAcceptors == 0)
        return 0.0;
    if (globals != NULL) {
        // Refresh the cached global parameter values and upload them only when something changed.

        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals->upload(globalParamValues);
    }
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;

        // Bind the donor kernel's arguments.  Arguments 0-7 are fixed buffers and local
        // memory, the next five slots are reserved for the periodic box, and the remaining
        // ones are the globals, the per-donor and per-acceptor parameters, and the
        // tabulated functions, in that order.

        int index = 0;
        donorKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, donorExclusions->getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, donors->getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, acceptors->getDeviceBuffer());
        donorKernel.setArg<cl::Buffer>(index++, donorBufferIndices->getDeviceBuffer());
        donorKernel.setArg(index++, 3*OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
        index += 5; // Periodic box size arguments are set when the kernel is executed.
        if (globals != NULL)
            donorKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
        for (int i = 0; i < (int) donorParams->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = donorParams->getBuffers()[i];
            donorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        }
        for (int i = 0; i < (int) acceptorParams->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
            donorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        }
        for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
            donorKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());

        // Bind the acceptor kernel's arguments.  The layout mirrors the donor kernel, except
        // that the acceptor exclusions and acceptor buffer indices are used.

        index = 0;
        acceptorKernel.setArg<cl::Buffer>(index++, cl.getForceBuffers().getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, acceptorExclusions->getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, donors->getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, acceptors->getDeviceBuffer());
        acceptorKernel.setArg<cl::Buffer>(index++, acceptorBufferIndices->getDeviceBuffer());
        acceptorKernel.setArg(index++, 3*OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
        index += 5; // Periodic box size arguments are set when the kernel is executed.
        if (globals != NULL)
            acceptorKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
        for (int i = 0; i < (int) donorParams->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = donorParams->getBuffers()[i];
            acceptorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        }
        for (int i = 0; i < (int) acceptorParams->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = acceptorParams->getBuffers()[i];
            acceptorKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        }
        for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
            acceptorKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());
    }

    // The periodic box arguments start at index 8 (the slots skipped above) and are
    // set on every call.  Both kernels are launched with the same work size.

    const int workSize = max(numDonors, numAcceptors);
    setPeriodicBoxArgs(cl, donorKernel, 8);
    cl.executeKernel(donorKernel, workSize);
    setPeriodicBoxArgs(cl, acceptorKernel, 8);
    cl.executeKernel(acceptorKernel, workSize);
    return 0.0;
}

4331
4332
4333
4334
4335
4336
4337
4338
4339
4340
4341
// Copy the per-donor and per-acceptor parameter values from the force object into
// the parameter sets used by the kernels.
//
// Only the slice of donors assigned to this context is copied; acceptors are not
// split between contexts.  An OpenMMException is thrown if the number of donors in
// this context's slice or the total number of acceptors differs from the values
// recorded earlier.  Atom indices returned by the force are read but ignored.
void OpenCLCalcCustomHbondForceKernel::copyParametersToContext(ContextImpl& context, const CustomHbondForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumDonors()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumDonors()/numContexts;
    if (numDonors != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of donors has changed");
    if (numAcceptors != force.getNumAcceptors())
        throw OpenMMException("updateParametersInContext: The number of acceptors has changed");

    // Record the per-donor parameters.

    if (numDonors > 0) {
        vector<vector<cl_float> > donorParamVector(numDonors);
        vector<double> parameters;
        for (int i = 0; i < numDonors; i++) {
            int d1, d2, d3;
            force.getDonorParameters(startIndex+i, d1, d2, d3, parameters);
            donorParamVector[i].resize(parameters.size());
            for (int j = 0; j < (int) parameters.size(); j++)
                donorParamVector[i][j] = (cl_float) parameters[j];
        }
        donorParams->setParameterValues(donorParamVector);
    }

    // Record the per-acceptor parameters.

    if (numAcceptors > 0) {
        vector<vector<cl_float> > acceptorParamVector(numAcceptors);
        vector<double> parameters;
        for (int i = 0; i < numAcceptors; i++) {
            int a1, a2, a3;
            force.getAcceptorParameters(i, a1, a2, a3, parameters);
            acceptorParamVector[i].resize(parameters.size());
            for (int j = 0; j < (int) parameters.size(); j++)
                acceptorParamVector[i][j] = (cl_float) parameters[j];
        }
        acceptorParams->setParameterValues(acceptorParamVector);
    }

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

4375
4376
class OpenCLCustomCompoundBondForceInfo : public OpenCLForceInfo {
public:
Peter Eastman's avatar
Peter Eastman committed
4377
    OpenCLCustomCompoundBondForceInfo(const CustomCompoundBondForce& force) : OpenCLForceInfo(0), force(force) {
4378
4379
4380
4381
4382
4383
4384
4385
4386
4387
4388
4389
4390
4391
4392
4393
4394
4395
4396
4397
4398
4399
4400
4401
4402
4403
4404
4405
4406
4407
4408
4409
4410
4411
4412
4413
4414
4415
4416
4417
4418
4419
4420
4421
4422
4423
4424
4425
4426
4427
4428
4429
4430
4431
4432
4433
    }
    int getNumParticleGroups() {
        return force.getNumBonds();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        vector<double> parameters;
        force.getBondParameters(index, particles, parameters);
    }
    bool areGroupsIdentical(int group1, int group2) {
        vector<int> particles;
        vector<double> parameters1, parameters2;
        force.getBondParameters(group1, particles, parameters1);
        force.getBondParameters(group2, particles, parameters2);
        for (int i = 0; i < (int) parameters1.size(); i++)
            if (parameters1[i] != parameters2[i])
                return false;
        return true;
    }
private:
    const CustomCompoundBondForce& force;
};

// Release the parameter set, the global parameter array, and every tabulated
// function array owned by this kernel.
OpenCLCalcCustomCompoundBondForceKernel::~OpenCLCalcCustomCompoundBondForceKernel() {
    // Deleting a null pointer is a no-op, so no explicit NULL tests are needed.
    delete params;
    delete globals;
    const int numTables = (int) tabulatedFunctions.size();
    for (int table = 0; table < numTables; ++table)
        delete tabulatedFunctions[table];
}

void OpenCLCalcCustomCompoundBondForceKernel::initialize(const System& system, const CustomCompoundBondForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    numBonds = endIndex-startIndex;
    if (numBonds == 0)
        return;
    int particlesPerBond = force.getNumParticlesPerBond();
    vector<vector<int> > atoms(numBonds, vector<int>(particlesPerBond));
    params = new OpenCLParameterSet(cl, force.getNumPerBondParameters(), numBonds, "customCompoundBondParams");
    vector<vector<cl_float> > paramVector(numBonds);
    for (int i = 0; i < numBonds; i++) {
        vector<double> parameters;
        force.getBondParameters(startIndex+i, atoms[i], parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);
    cl.addForce(new OpenCLCustomCompoundBondForceInfo(force));

    // Record the tabulated functions.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
4434
    vector<const TabulatedFunction*> functionList;
4435
4436
    stringstream tableArgs;
    for (int i = 0; i < force.getNumFunctions(); i++) {
4437
4438
4439
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
peastman's avatar
peastman committed
4440
        int width;
4441
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
4442
        OpenCLArray* array = OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction");
4443
4444
        tabulatedFunctions.push_back(array);
        array->upload(f);
peastman's avatar
peastman committed
4445
        string arrayName = cl.getBondedUtilities().addArgument(array->getDeviceBuffer(), width == 1 ? "float" : "float"+cl.intToString(width));
4446
4447
4448
4449
4450
4451
4452
4453
4454
4455
4456
4457
4458
        functionDefinitions.push_back(make_pair(name, arrayName));
    }
    
    // Record information about parameters.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
    }
    map<string, string> variables;
    for (int i = 0; i < particlesPerBond; i++) {
4459
        string index = cl.intToString(i+1);
4460
4461
4462
4463
4464
4465
4466
4467
4468
        variables["x"+index] = "pos"+index+".x";
        variables["y"+index] = "pos"+index+".y";
        variables["z"+index] = "pos"+index+".z";
    }
    for (int i = 0; i < force.getNumPerBondParameters(); i++) {
        const string& name = force.getPerBondParameterName(i);
        variables[name] = "bondParams"+params->getParameterSuffix(i);
    }
    if (force.getNumGlobalParameters() > 0) {
4469
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customCompoundBondGlobals", CL_MEM_READ_ONLY);
4470
4471
4472
4473
        globals->upload(globalParamValues);
        string argName = cl.getBondedUtilities().addArgument(globals->getDeviceBuffer(), "float");
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
4474
            string value = argName+"["+cl.intToString(i)+"]";
4475
4476
4477
4478
4479
4480
4481
4482
4483
4484
4485
4486
4487
4488
4489
            variables[name] = value;
        }
    }

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomCompoundBondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    vector<string> atomNames, posNames;
    for (int i = 0; i < particlesPerBond; i++) {
4490
        string index = cl.intToString(i+1);
4491
4492
4493
4494
4495
4496
4497
4498
4499
        atomNames.push_back("P"+index);
        posNames.push_back("pos"+index);
    }
    stringstream compute;
    int index = 0;
    for (map<string, vector<int> >::const_iterator iter = distances.begin(); iter != distances.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        if (computedDeltas.count(deltaName) == 0) {
4500
            compute<<"real4 delta"<<deltaName<<" = ccb_delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<");\n";
4501
4502
            computedDeltas.insert(deltaName);
        }
4503
        compute<<"real r_"<<deltaName<<" = sqrt(delta"<<deltaName<<".w);\n";
4504
        variables[iter->first] = "r_"+deltaName;
4505
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
4506
4507
4508
4509
4510
4511
4512
4513
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = angles.begin(); iter != angles.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
        if (computedDeltas.count(deltaName1) == 0) {
4514
            compute<<"real4 delta"<<deltaName1<<" = ccb_delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[0]]<<");\n";
4515
4516
4517
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
4518
            compute<<"real4 delta"<<deltaName2<<" = ccb_delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[2]]<<");\n";
4519
4520
            computedDeltas.insert(deltaName2);
        }
4521
        compute<<"real "<<angleName<<" = ccb_computeAngle(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
4522
        variables[iter->first] = angleName;
4523
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
4524
4525
4526
4527
4528
4529
4530
4531
4532
4533
4534
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = dihedrals.begin(); iter != dihedrals.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
        if (computedDeltas.count(deltaName1) == 0) {
4535
            compute<<"real4 delta"<<deltaName1<<" = ccb_delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<");\n";
4536
4537
4538
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
4539
            compute<<"real4 delta"<<deltaName2<<" = ccb_delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[1]]<<");\n";
4540
4541
4542
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
4543
            compute<<"real4 delta"<<deltaName3<<" = ccb_delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[3]]<<");\n";
4544
4545
            computedDeltas.insert(deltaName3);
        }
4546
4547
4548
        compute<<"real4 "<<crossName1<<" = ccb_computeCross(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
        compute<<"real4 "<<crossName2<<" = ccb_computeCross(delta"<<deltaName2<<", delta"<<deltaName3<<");\n";
        compute<<"real "<<dihedralName<<" = ccb_computeAngle("<<crossName1<<", "<<crossName2<<");\n";
4549
4550
        compute<<dihedralName<<" *= (delta"<<deltaName1<<".x*"<<crossName2<<".x + delta"<<deltaName1<<".y*"<<crossName2<<".y + delta"<<deltaName1<<".z*"<<crossName2<<".z < 0 ? -1 : 1);\n";
        variables[iter->first] = dihedralName;
4551
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
4552
4553
4554
4555
4556
4557
4558
4559
4560
4561
    }

    // Now evaluate the expressions.

    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        const OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        string argName = cl.getBondedUtilities().addArgument(buffer.getMemory(), buffer.getType());
        compute<<buffer.getType()<<" bondParams"<<(i+1)<<" = "<<argName<<"[index];\n";
    }
    forceExpressions["energy += "] = energyExpression;
4562
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");
4563
4564
4565
4566
4567

    // Finally, apply forces to atoms.

    vector<string> forceNames;
    for (int i = 0; i < particlesPerBond; i++) {
4568
        string istr = cl.intToString(i+1);
4569
4570
        string forceName = "force"+istr;
        forceNames.push_back(forceName);
4571
        compute<<"real4 "<<forceName<<" = (real4) 0;\n";
4572
4573
4574
4575
4576
4577
4578
4579
4580
4581
4582
4583
        compute<<"{\n";
        Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x"+istr).optimize();
        Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y"+istr).optimize();
        Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z"+istr).optimize();
        map<string, Lepton::ParsedExpression> expressions;
        if (!isZeroExpression(forceExpressionX))
            expressions[forceName+".x -= "] = forceExpressionX;
        if (!isZeroExpression(forceExpressionY))
            expressions[forceName+".y -= "] = forceExpressionY;
        if (!isZeroExpression(forceExpressionZ))
            expressions[forceName+".z -= "] = forceExpressionZ;
        if (expressions.size() > 0)
4584
            compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
4585
4586
4587
4588
4589
4590
        compute<<"}\n";
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = distances.begin(); iter != distances.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
4591
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
4592
4593
4594
4595
4596
4597
4598
4599
4600
        compute<<forceNames[atoms[0]]<<".xyz += "<<"-"<<value<<";\n";
        compute<<forceNames[atoms[1]]<<".xyz += "<<value<<";\n";
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = angles.begin(); iter != angles.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        compute<<"{\n";
4601
4602
4603
4604
4605
        compute<<"real4 crossProd = cross(delta"<<deltaName2<<", delta"<<deltaName1<<");\n";
        compute<<"real lengthCross = max(length(crossProd), (real) 1e-6f);\n";
        compute<<"real4 deltaCross0 = -cross(delta"<<deltaName1<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName1<<".w*lengthCross);\n";
        compute<<"real4 deltaCross2 = cross(delta"<<deltaName2<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName2<<".w*lengthCross);\n";
        compute<<"real4 deltaCross1 = -(deltaCross0+deltaCross2);\n";
4606
4607
4608
4609
4610
4611
4612
4613
4614
4615
4616
4617
4618
4619
        compute<<forceNames[atoms[0]]<<".xyz += deltaCross0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += deltaCross1.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += deltaCross2.xyz;\n";
        compute<<"}\n";
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = dihedrals.begin(); iter != dihedrals.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        compute<<"{\n";
4620
4621
4622
        compute<<"real r = SQRT(delta"<<deltaName2<<".w);\n";
        compute<<"real4 ff;\n";
        compute<<"ff.x = (-dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName1<<".w;\n";
4623
4624
        compute<<"ff.y = (delta"<<deltaName1<<".x*delta"<<deltaName2<<".x + delta"<<deltaName1<<".y*delta"<<deltaName2<<".y + delta"<<deltaName1<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.z = (delta"<<deltaName3<<".x*delta"<<deltaName2<<".x + delta"<<deltaName3<<".y*delta"<<deltaName2<<".y + delta"<<deltaName3<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
4625
4626
4627
4628
        compute<<"ff.w = (dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName2<<".w;\n";
        compute<<"real4 internalF0 = ff.x*"<<crossName1<<";\n";
        compute<<"real4 internalF3 = ff.w*"<<crossName2<<";\n";
        compute<<"real4 s = ff.y*internalF0 - ff.z*internalF3;\n";
4629
4630
4631
4632
4633
4634
4635
4636
        compute<<forceNames[atoms[0]]<<".xyz += internalF0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += s.xyz-internalF0.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += -s.xyz-internalF3.xyz;\n";
        compute<<forceNames[atoms[3]]<<".xyz += internalF3.xyz;\n";
        compute<<"}\n";
    }
    cl.getBondedUtilities().addInteraction(atoms, compute.str(), force.getForceGroup());
    map<string, string> replacements;
4637
    replacements["M_PI"] = cl.doubleToString(M_PI);
4638
4639
4640
4641
4642
4643
4644
4645
4646
4647
4648
4649
4650
4651
4652
4653
4654
4655
    cl.getBondedUtilities().addPrefixCode(cl.replaceStrings(OpenCLKernelSources::customCompoundBond, replacements));;
}

// Synchronize the global parameter array with the context before the bonded
// interaction is evaluated.  The values are uploaded only when at least one of
// them differs from the cached copy.  Always returns 0.0.
double OpenCLCalcCustomCompoundBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    // Without a globals array there is nothing to synchronize.
    if (globals == NULL)
        return 0.0;

    bool needsUpload = false;
    const int numGlobals = (int) globalParamNames.size();
    for (int paramIndex = 0; paramIndex < numGlobals; ++paramIndex) {
        const cl_float current = (cl_float) context.getParameter(globalParamNames[paramIndex]);
        needsUpload = needsUpload || (current != globalParamValues[paramIndex]);
        // The cached value is always overwritten, whether or not it compared equal.
        globalParamValues[paramIndex] = current;
    }
    if (needsUpload)
        globals->upload(globalParamValues);
    return 0.0;
}

4656
4657
4658
4659
4660
4661
// Copy the per-bond parameter values from the force object into the parameter set
// used by the kernel.
//
// Only the slice of bonds assigned to this context is copied.  An OpenMMException
// is thrown if the size of that slice differs from the number of bonds recorded
// earlier.  If this context has no bonds, nothing else is done.  The particle
// indices returned by the force are read but ignored.
void OpenCLCalcCustomCompoundBondForceKernel::copyParametersToContext(ContextImpl& context, const CustomCompoundBondForce& force) {
    int numContexts = cl.getPlatformData().contexts.size();
    int startIndex = cl.getContextIndex()*force.getNumBonds()/numContexts;
    int endIndex = (cl.getContextIndex()+1)*force.getNumBonds()/numContexts;
    if (numBonds != endIndex-startIndex)
        throw OpenMMException("updateParametersInContext: The number of bonds has changed");
    if (numBonds == 0)
        return;

    // Record the per-bond parameters.

    vector<vector<cl_float> > paramVector(numBonds);
    vector<int> particles;
    vector<double> parameters;
    for (int i = 0; i < numBonds; i++) {
        force.getBondParameters(startIndex+i, particles, parameters);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (cl_float) parameters[j];
    }
    params->setParameterValues(paramVector);

    // Mark that the current reordering may be invalid.

    cl.invalidateMolecules();
}

4683
4684
4685
4686
4687
4688
4689
4690
4691
4692
4693
4694
4695
4696
4697
4698
4699
4700
4701
4702
4703
4704
4705
4706
4707
4708
4709
4710
4711
4712
4713
4714
4715
4716
4717
4718
4719
4720
4721
4722
4723
4724
4725
4726
4727
4728
4729
4730
4731
4732
4733
4734
4735
4736
4737
4738
4739
4740
4741
4742
4743
4744
4745
4746
4747
4748
4749
4750
4751
4752
4753
4754
4755
4756
4757
4758
4759
4760
4761
4762
4763
4764
4765
4766
4767
4768
4769
4770
4771
4772
4773
4774
4775
4776
4777
4778
4779
4780
4781
4782
4783
4784
4785
4786
4787
4788
4789
4790
4791
4792
4793
4794
4795
4796
4797
4798
4799
4800
4801
4802
4803
4804
4805
4806
4807
4808
4809
4810
4811
4812
4813
4814
4815
4816
4817
4818
4819
4820
4821
4822
4823
4824
4825
4826
4827
4828
4829
4830
4831
4832
4833
4834
4835
4836
4837
4838
4839
4840
4841
4842
4843
4844
4845
4846
4847
4848
4849
4850
4851
4852
4853
4854
4855
4856
4857
4858
4859
4860
4861
4862
4863
4864
4865
4866
4867
4868
4869
4870
4871
4872
4873
4874
4875
4876
4877
4878
4879
4880
4881
4882
4883
4884
4885
4886
4887
4888
4889
4890
4891
4892
4893
4894
4895
4896
4897
4898
4899
4900
4901
4902
4903
4904
4905
4906
4907
4908
4909
4910
4911
4912
class OpenCLCustomManyParticleForceInfo : public OpenCLForceInfo {
public:
    OpenCLCustomManyParticleForceInfo(const CustomManyParticleForce& force) : OpenCLForceInfo(0), force(force) {
    }
    bool areParticlesIdentical(int particle1, int particle2) {
        vector<double> params1, params2;
        int type1, type2;
        force.getParticleParameters(particle1, params1, type1);
        force.getParticleParameters(particle2, params2, type2);
        if (type1 != type2)
            return false;
        for (int i = 0; i < (int) params1.size(); i++)
            if (params1[i] != params2[i])
                return false;
        return true;
    }
    int getNumParticleGroups() {
        return force.getNumExclusions();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        int particle1, particle2;
        force.getExclusionParticles(index, particle1, particle2);
        particles.resize(2);
        particles[0] = particle1;
        particles[1] = particle2;
    }
    bool areGroupsIdentical(int group1, int group2) {
        return true;
    }
private:
    const CustomManyParticleForce& force;
};

// Release every device array and parameter set owned by this kernel, followed by
// the tabulated function arrays.
OpenCLCalcCustomManyParticleForceKernel::~OpenCLCalcCustomManyParticleForceKernel() {
    // Deleting a null pointer is a no-op, so members that were never allocated
    // need no explicit NULL test.
    delete params;
    delete globals;
    delete orderIndex;
    delete particleOrder;
    delete particleTypes;
    delete exclusions;
    delete exclusionStartIndex;
    delete blockCenter;
    delete blockBoundingBox;
    delete neighborPairs;
    delete numNeighborPairs;
    delete neighborStartIndex;
    delete neighbors;
    delete numNeighborsForAtom;
    const int numTables = (int) tabulatedFunctions.size();
    for (int table = 0; table < numTables; ++table)
        delete tabulatedFunctions[table];
}

// Build all device-side state for a CustomManyParticleForce.  This uploads the per-particle
// and global parameters, tabulated functions, type filters and exclusions; allocates the
// neighbor list arrays when a cutoff is in use; generates the OpenCL source fragments that
// evaluate the interaction; and finally compiles the program and looks up its kernels.
//
// system: the System being simulated (not read by this function)
// force:  the CustomManyParticleForce whose definition is compiled into the kernels
//
// Throws OpenMMException if the device does not support 64 bit global atomics.
void OpenCLCalcCustomManyParticleForceKernel::initialize(const System& system, const CustomManyParticleForce& force) {
    if (!cl.getSupports64BitGlobalAtomics())
        throw OpenMMException("CustomManyParticleForce requires a device that supports 64 bit atomic operations");
    int numParticles = force.getNumParticles();
    int particlesPerSet = force.getNumParticlesPerSet();
    bool centralParticleMode = (force.getPermutationMode() == CustomManyParticleForce::UniqueCentralParticle);
    nonbondedMethod = CalcCustomManyParticleForceKernel::NonbondedMethod(force.getNonbondedMethod());
    forceWorkgroupSize = 128;
    findNeighborsWorkgroupSize = (cl.getSIMDWidth() >= 32 ? 128 : 32);

    // Record parameter values.  They are stored on the device in single precision.

    params = new OpenCLParameterSet(cl, force.getNumPerParticleParameters(), numParticles, "customManyParticleParameters");
    vector<vector<float> > paramVector(numParticles);
    for (int i = 0; i < numParticles; i++) {
        vector<double> parameters;
        int type;
        force.getParticleParameters(i, parameters, type);
        paramVector[i].resize(parameters.size());
        for (int j = 0; j < (int) parameters.size(); j++)
            paramVector[i][j] = (float) parameters[j];
    }
    params->setParameterValues(paramVector);
    cl.addForce(new OpenCLCustomManyParticleForceInfo(force));

    // Record the tabulated functions.  Each one is uploaded to its own array and exposed
    // to the kernel as an extra argument named "table<i>".
    // NOTE(review): the placeholder objects stored in "functions" are never deleted in this
    // function; confirm whether ownership is taken elsewhere.

    map<string, Lepton::CustomFunction*> functions;
    vector<pair<string, string> > functionDefinitions;
    vector<const TabulatedFunction*> functionList;
    stringstream tableArgs;
    for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
        functionList.push_back(&force.getTabulatedFunction(i));
        string name = force.getTabulatedFunctionName(i);
        string arrayName = "table"+cl.intToString(i);
        functionDefinitions.push_back(make_pair(name, arrayName));
        functions[name] = cl.getExpressionUtilities().getFunctionPlaceholder(force.getTabulatedFunction(i));
        int width;
        vector<float> f = cl.getExpressionUtilities().computeFunctionCoefficients(force.getTabulatedFunction(i), width);
        tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
        tabulatedFunctions.back()->upload(f);
        tableArgs << ", __global const float";
        if (width > 1)
            tableArgs << width;
        tableArgs << "* restrict " << arrayName;
    }

    // Record information about parameters.

    globalParamNames.resize(force.getNumGlobalParameters());
    globalParamValues.resize(force.getNumGlobalParameters());
    for (int i = 0; i < force.getNumGlobalParameters(); i++) {
        globalParamNames[i] = force.getGlobalParameterName(i);
        globalParamValues[i] = (float) force.getGlobalParameterDefaultValue(i);
    }

    // Map every variable the expression may reference to the kernel code that provides
    // its value: particle coordinates, per-particle parameters, then global parameters.

    vector<pair<ExpressionTreeNode, string> > variables;
    for (int i = 0; i < particlesPerSet; i++) {
        string index = cl.intToString(i+1);
        variables.push_back(makeVariable("x"+index, "pos"+index+".x"));
        variables.push_back(makeVariable("y"+index, "pos"+index+".y"));
        variables.push_back(makeVariable("z"+index, "pos"+index+".z"));
    }
    for (int i = 0; i < force.getNumPerParticleParameters(); i++) {
        const string& name = force.getPerParticleParameterName(i);
        for (int j = 0; j < particlesPerSet; j++) {
            string index = cl.intToString(j+1);
            variables.push_back(makeVariable(name+index, "params"+params->getParameterSuffix(i, index)));
        }
    }
    if (force.getNumGlobalParameters() > 0) {
        globals = OpenCLArray::create<cl_float>(cl, force.getNumGlobalParameters(), "customManyParticleGlobals", CL_MEM_READ_ONLY);
        globals->upload(globalParamValues);
        for (int i = 0; i < force.getNumGlobalParameters(); i++) {
            const string& name = force.getGlobalParameterName(i);
            string value = "globals["+cl.intToString(i)+"]";
            variables.push_back(makeVariable(name, value));
        }
    }

    // Build data structures for type filters.  Filters are only used when more than one
    // particle order was produced.

    vector<int> particleTypesVec;
    vector<int> orderIndexVec;
    vector<std::vector<int> > particleOrderVec;
    int numTypes;
    CustomManyParticleForceImpl::buildFilterArrays(force, numTypes, particleTypesVec, orderIndexVec, particleOrderVec);
    bool hasTypeFilters = (particleOrderVec.size() > 1);
    if (hasTypeFilters) {
        particleTypes = OpenCLArray::create<int>(cl, particleTypesVec.size(), "customManyParticleTypes");
        orderIndex = OpenCLArray::create<int>(cl, orderIndexVec.size(), "customManyParticleOrderIndex");
        particleOrder = OpenCLArray::create<int>(cl, particleOrderVec.size()*particlesPerSet, "customManyParticleOrder");
        particleTypes->upload(particleTypesVec);
        orderIndex->upload(orderIndexVec);
        vector<int> flattenedOrder(particleOrder->getSize());
        for (int i = 0; i < (int) particleOrderVec.size(); i++)
            for (int j = 0; j < particlesPerSet; j++)
                flattenedOrder[i*particlesPerSet+j] = particleOrderVec[i][j];
        particleOrder->upload(flattenedOrder);
    }

    // Build data structures for exclusions.  The exclusions of all particles are stored in
    // one flat, per-particle sorted array; exclusionStartIndex[i] and exclusionStartIndex[i+1]
    // delimit the entries belonging to particle i.

    if (force.getNumExclusions() > 0) {
        vector<vector<int> > particleExclusions(numParticles);
        for (int i = 0; i < force.getNumExclusions(); i++) {
            int p1, p2;
            force.getExclusionParticles(i, p1, p2);
            particleExclusions[p1].push_back(p2);
            particleExclusions[p2].push_back(p1);
        }
        vector<int> exclusionsVec;
        vector<int> exclusionStartIndexVec(numParticles+1);
        exclusionStartIndexVec[0] = 0;
        for (int i = 0; i < numParticles; i++) {
            sort(particleExclusions[i].begin(), particleExclusions[i].end());
            exclusionsVec.insert(exclusionsVec.end(), particleExclusions[i].begin(), particleExclusions[i].end());
            exclusionStartIndexVec[i+1] = exclusionsVec.size();
        }
        exclusions = OpenCLArray::create<int>(cl, exclusionsVec.size(), "customManyParticleExclusions");
        exclusionStartIndex = OpenCLArray::create<int>(cl, exclusionStartIndexVec.size(), "customManyParticleExclusionStart");
        exclusions->upload(exclusionsVec);
        exclusionStartIndex->upload(exclusionStartIndexVec);
    }

    // Build data structures for the neighbor list.

    if (nonbondedMethod != NoCutoff) {
        int numAtomBlocks = cl.getNumAtomBlocks();
        int elementSize = (cl.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
        blockCenter = new OpenCLArray(cl, numAtomBlocks, 4*elementSize, "blockCenter");
        blockBoundingBox = new OpenCLArray(cl, numAtomBlocks, 4*elementSize, "blockBoundingBox");
        numNeighborPairs = OpenCLArray::create<int>(cl, 1, "customManyParticleNumNeighborPairs");
        neighborStartIndex = OpenCLArray::create<int>(cl, numParticles+1, "customManyParticleNeighborStartIndex");
        numNeighborsForAtom = OpenCLArray::create<int>(cl, numParticles, "customManyParticleNumNeighborsForAtom");

        // Select a size for the array that holds the neighbor list.  We have to make a fairly
        // arbitrary guess, but if this turns out to be too small we'll increase it later.

        maxNeighborPairs = 150*numParticles;
        neighborPairs = OpenCLArray::create<mm_int2>(cl, maxNeighborPairs, "customManyParticleNeighborPairs");
        neighbors = OpenCLArray::create<int>(cl, maxNeighborPairs, "customManyParticleNeighbors");
    }

    // Now to generate the kernel.  First, it needs to calculate all distances, angles,
    // and dihedrals the expression depends on.  computedDeltas remembers which displacement
    // vectors have already been emitted so each one is declared only once.

    map<string, vector<int> > distances;
    map<string, vector<int> > angles;
    map<string, vector<int> > dihedrals;
    Lepton::ParsedExpression energyExpression = CustomManyParticleForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
    map<string, Lepton::ParsedExpression> forceExpressions;
    set<string> computedDeltas;
    vector<string> atomNames, posNames;
    for (int i = 0; i < particlesPerSet; i++) {
        string index = cl.intToString(i+1);
        atomNames.push_back("P"+index);
        posNames.push_back("pos"+index);
    }
    stringstream compute;
    int index = 0;
    for (map<string, vector<int> >::const_iterator iter = distances.begin(); iter != distances.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        if (computedDeltas.count(deltaName) == 0) {
            compute<<"real4 delta"<<deltaName<<" = delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
            computedDeltas.insert(deltaName);
        }
        compute<<"real r_"<<deltaName<<" = sqrt(delta"<<deltaName<<".w);\n";
        variables.push_back(makeVariable(iter->first, "r_"+deltaName));
        forceExpressions["real dEdDistance"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = angles.begin(); iter != angles.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        string angleName = "angle_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]];
        if (computedDeltas.count(deltaName1) == 0) {
            compute<<"real4 delta"<<deltaName1<<" = delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[0]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
            compute<<"real4 delta"<<deltaName2<<" = delta("<<posNames[atoms[1]]<<", "<<posNames[atoms[2]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
            computedDeltas.insert(deltaName2);
        }
        compute<<"real "<<angleName<<" = computeAngle(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
        variables.push_back(makeVariable(iter->first, angleName));
        forceExpressions["real dEdAngle"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
    }
    index = 0;
    for (map<string, vector<int> >::const_iterator iter = dihedrals.begin(); iter != dihedrals.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        string dihedralName = "dihedral_"+atomNames[atoms[0]]+atomNames[atoms[1]]+atomNames[atoms[2]]+atomNames[atoms[3]];
        if (computedDeltas.count(deltaName1) == 0) {
            compute<<"real4 delta"<<deltaName1<<" = delta("<<posNames[atoms[0]]<<", "<<posNames[atoms[1]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
            computedDeltas.insert(deltaName1);
        }
        if (computedDeltas.count(deltaName2) == 0) {
            compute<<"real4 delta"<<deltaName2<<" = delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[1]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
            computedDeltas.insert(deltaName2);
        }
        if (computedDeltas.count(deltaName3) == 0) {
            compute<<"real4 delta"<<deltaName3<<" = delta("<<posNames[atoms[2]]<<", "<<posNames[atoms[3]]<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ);\n";
            computedDeltas.insert(deltaName3);
        }
        compute<<"real4 "<<crossName1<<" = computeCross(delta"<<deltaName1<<", delta"<<deltaName2<<");\n";
        compute<<"real4 "<<crossName2<<" = computeCross(delta"<<deltaName2<<", delta"<<deltaName3<<");\n";
        compute<<"real "<<dihedralName<<" = computeAngle("<<crossName1<<", "<<crossName2<<");\n";

        // The sign of the dihedral is taken from the dot product of the first delta with
        // the second cross product.

        compute<<dihedralName<<" *= (delta"<<deltaName1<<".x*"<<crossName2<<".x + delta"<<deltaName1<<".y*"<<crossName2<<".y + delta"<<deltaName1<<".z*"<<crossName2<<".z < 0 ? -1 : 1);\n";
        variables.push_back(makeVariable(iter->first, dihedralName));
        forceExpressions["real dEdDihedral"+cl.intToString(index)+" = "] = energyExpression.differentiate(iter->first).optimize();
    }

    // Now evaluate the expressions.

    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        compute<<buffer.getType()<<" params"<<(i+1)<<" = global_params"<<(i+1)<<"[index];\n";
    }
    forceExpressions["energy += "] = energyExpression;
    compute << cl.getExpressionUtilities().createExpressions(forceExpressions, variables, functionList, functionDefinitions, "temp");

    // Apply forces to atoms.  Start with the direct dependence of the energy on each
    // particle's coordinates; components whose derivative is identically zero are skipped.

    vector<string> forceNames;
    for (int i = 0; i < particlesPerSet; i++) {
        string istr = cl.intToString(i+1);
        string forceName = "force"+istr;
        forceNames.push_back(forceName);
        compute<<"real4 "<<forceName<<" = (real4) 0;\n";
        compute<<"{\n";
        Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x"+istr).optimize();
        Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y"+istr).optimize();
        Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z"+istr).optimize();
        map<string, Lepton::ParsedExpression> expressions;
        if (!isZeroExpression(forceExpressionX))
            expressions[forceName+".x -= "] = forceExpressionX;
        if (!isZeroExpression(forceExpressionY))
            expressions[forceName+".y -= "] = forceExpressionY;
        if (!isZeroExpression(forceExpressionZ))
            expressions[forceName+".z -= "] = forceExpressionZ;
        if (expressions.size() > 0)
            compute<<cl.getExpressionUtilities().createExpressions(expressions, variables, functionList, functionDefinitions, "coordtemp");
        compute<<"}\n";
    }

    // Add the forces that arise through each distance.

    index = 0;
    for (map<string, vector<int> >::const_iterator iter = distances.begin(); iter != distances.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName = atomNames[atoms[0]]+atomNames[atoms[1]];
        string value = "(dEdDistance"+cl.intToString(index)+"/r_"+deltaName+")*delta"+deltaName+".xyz";
        compute<<forceNames[atoms[0]]<<".xyz += "<<"-"<<value<<";\n";
        compute<<forceNames[atoms[1]]<<".xyz += "<<value<<";\n";
    }

    // Add the forces that arise through each angle.

    index = 0;
    for (map<string, vector<int> >::const_iterator iter = angles.begin(); iter != angles.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[1]]+atomNames[atoms[0]];
        string deltaName2 = atomNames[atoms[1]]+atomNames[atoms[2]];
        compute<<"{\n";
        compute<<"real4 crossProd = cross(delta"<<deltaName2<<", delta"<<deltaName1<<");\n";

        // The length is clamped from below so the divisions that follow never divide by zero.

        compute<<"real lengthCross = max(SQRT(dot(crossProd, crossProd)), (real) 1e-6f);\n";
        compute<<"real4 deltaCross0 = -cross(delta"<<deltaName1<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName1<<".w*lengthCross);\n";
        compute<<"real4 deltaCross2 = cross(delta"<<deltaName2<<", crossProd)*dEdAngle"<<cl.intToString(index)<<"/(delta"<<deltaName2<<".w*lengthCross);\n";
        compute<<"real4 deltaCross1 = -(deltaCross0+deltaCross2);\n";
        compute<<forceNames[atoms[0]]<<".xyz += deltaCross0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += deltaCross1.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += deltaCross2.xyz;\n";
        compute<<"}\n";
    }

    // Add the forces that arise through each dihedral.

    index = 0;
    for (map<string, vector<int> >::const_iterator iter = dihedrals.begin(); iter != dihedrals.end(); ++iter, ++index) {
        const vector<int>& atoms = iter->second;
        string deltaName1 = atomNames[atoms[0]]+atomNames[atoms[1]];
        string deltaName2 = atomNames[atoms[2]]+atomNames[atoms[1]];
        string deltaName3 = atomNames[atoms[2]]+atomNames[atoms[3]];
        string crossName1 = "cross_"+deltaName1+"_"+deltaName2;
        string crossName2 = "cross_"+deltaName2+"_"+deltaName3;
        compute<<"{\n";
        compute<<"real r = sqrt(delta"<<deltaName2<<".w);\n";
        compute<<"real4 ff;\n";
        compute<<"ff.x = (-dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName1<<".w;\n";
        compute<<"ff.y = (delta"<<deltaName1<<".x*delta"<<deltaName2<<".x + delta"<<deltaName1<<".y*delta"<<deltaName2<<".y + delta"<<deltaName1<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.z = (delta"<<deltaName3<<".x*delta"<<deltaName2<<".x + delta"<<deltaName3<<".y*delta"<<deltaName2<<".y + delta"<<deltaName3<<".z*delta"<<deltaName2<<".z)/delta"<<deltaName2<<".w;\n";
        compute<<"ff.w = (dEdDihedral"<<cl.intToString(index)<<"*r)/"<<crossName2<<".w;\n";
        compute<<"real4 internalF0 = ff.x*"<<crossName1<<";\n";
        compute<<"real4 internalF3 = ff.w*"<<crossName2<<";\n";
        compute<<"real4 s = ff.y*internalF0 - ff.z*internalF3;\n";
        compute<<forceNames[atoms[0]]<<".xyz += internalF0.xyz;\n";
        compute<<forceNames[atoms[1]]<<".xyz += s.xyz-internalF0.xyz;\n";
        compute<<forceNames[atoms[2]]<<".xyz += -s.xyz-internalF3.xyz;\n";
        compute<<forceNames[atoms[3]]<<".xyz += internalF3.xyz;\n";
        compute<<"}\n";
    }

    // Store forces to global memory.

    for (int i = 0; i < particlesPerSet; i++)
        compute<<"storeForce(atom"<<(i+1)<<", "<<forceNames[i]<<", forceBuffers);\n";

    // Create other replacements that depend on the number of particles per set.

    stringstream numCombinations, atomsForCombination, isValidCombination, permute, loadData, verifyCutoff, verifyExclusions;
    if (hasTypeFilters) {
        permute<<"int particleSet[] = {";
        for (int i = 0; i < particlesPerSet; i++) {
            permute<<"p"<<(i+1);
            if (i < particlesPerSet-1)
                permute<<", ";
        }
        permute<<"};\n";
    }
    for (int i = 0; i < particlesPerSet; i++) {
        // With type filters the atoms are reordered through the particleOrder table;
        // otherwise they are used in the order they were found.

        if (hasTypeFilters)
            permute<<"int atom"<<(i+1)<<" = particleSet[particleOrder["<<particlesPerSet<<"*order+"<<i<<"]];\n";
        else
            permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n";
        loadData<<"real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n";
        for (int j = 0; j < (int) params->getBuffers().size(); j++)
            loadData<<params->getBuffers()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n";
    }

    // Build the test that rejects duplicate or repeated combinations of particles.

    if (centralParticleMode) {
        for (int i = 1; i < particlesPerSet; i++) {
            if (i > 1)
                isValidCombination<<" && p"<<(i+1)<<">p"<<i<<" && ";
            isValidCombination<<"p"<<(i+1)<<"!=p1";
        }
    }
    else {
        for (int i = 2; i < particlesPerSet; i++) {
            if (i > 2)
                isValidCombination<<" && ";
            isValidCombination<<"a"<<(i+1)<<">a"<<i;
        }
    }

    // Decode a flat combination index into one neighbor offset (a2, a3, ...) per
    // additional particle, treating the index as a number in base numNeighbors.

    atomsForCombination<<"int tempIndex = index;\n";
    for (int i = 1; i < particlesPerSet; i++) {
        if (i > 1)
            numCombinations<<"*";
        numCombinations<<"numNeighbors";
        if (centralParticleMode)
            atomsForCombination<<"int a"<<(i+1)<<" = tempIndex%numNeighbors;\n";
        else
            atomsForCombination<<"int a"<<(i+1)<<" = 1+tempIndex%numNeighbors;\n";
        if (i < particlesPerSet-1)
            atomsForCombination<<"tempIndex /= numNeighbors;\n";
    }
    if (particlesPerSet > 2) {
        // For odd a3 the value of a2 is mirrored within the neighbor range.
        // NOTE(review): the reason for alternating the direction is not visible here.

        if (centralParticleMode)
            atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2-1);\n";
        else
            atomsForCombination<<"a2 = (a3%2 == 0 ? a2 : numNeighbors-a2+1);\n";
    }

    // Translate the neighbor offsets into particle indices, either directly (no cutoff)
    // or by looking them up in the neighbor list.

    for (int i = 1; i < particlesPerSet; i++) {
        if (nonbondedMethod == NoCutoff) {
            if (centralParticleMode)
                atomsForCombination<<"int p"<<(i+1)<<" = a"<<(i+1)<<";\n";
            else
                atomsForCombination<<"int p"<<(i+1)<<" = p1+a"<<(i+1)<<";\n";
        }
        else {
            if (centralParticleMode)
                atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor+a"<<(i+1)<<"];\n";
            else
                atomsForCombination<<"int p"<<(i+1)<<" = neighbors[firstNeighbor-1+a"<<(i+1)<<"];\n";
        }
    }
    if (nonbondedMethod != NoCutoff) {
        for (int i = 1; i < particlesPerSet; i++)
            verifyCutoff<<"real4 pos"<<(i+1)<<" = posq[p"<<(i+1)<<"];\n";
        if (!centralParticleMode) {
            // Outside central particle mode, every pair among the additional particles
            // must also be within the cutoff.

            for (int i = 1; i < particlesPerSet; i++) {
                for (int j = i+1; j < particlesPerSet; j++)
                    verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize, periodicBoxVecX, periodicBoxVecY, periodicBoxVecZ).w < CUTOFF_SQUARED);\n";
            }
        }
    }
    if (force.getNumExclusions() > 0) {
        // When a cutoff is used, pairs involving the first particle are not re-checked here;
        // the neighbor list kernel is also given the exclusion arrays.

        int startCheckFrom = (nonbondedMethod == NoCutoff ? 0 : 1);
        for (int i = startCheckFrom; i < particlesPerSet; i++)
            for (int j = i+1; j < particlesPerSet; j++)
                verifyExclusions<<"includeInteraction &= !isInteractionExcluded(p"<<(i+1)<<", p"<<(j+1)<<", exclusions, exclusionStartIndex);\n";
    }

    // Combine the types of all particles in the set into a single index, using numTypes
    // as the base.

    string computeTypeIndex = "particleTypes[p"+cl.intToString(particlesPerSet)+"]";
    for (int i = particlesPerSet-2; i >= 0; i--)
        computeTypeIndex = "particleTypes[p"+cl.intToString(i+1)+"]+"+cl.intToString(numTypes)+"*("+computeTypeIndex+")";

    // Create replacements for extra arguments.

    stringstream extraArgs;
    if (force.getNumGlobalParameters() > 0)
        extraArgs << ", __global const float* globals";
    for (int i = 0; i < (int) params->getBuffers().size(); i++) {
        OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
        extraArgs<<", __global const "<<buffer.getType()<<"* restrict global_params"<<(i+1);
    }

    // Create the kernels.

    map<string, string> replacements;
    replacements["COMPUTE_INTERACTION"] = compute.str();
    replacements["NUM_CANDIDATE_COMBINATIONS"] = numCombinations.str();
    replacements["FIND_ATOMS_FOR_COMBINATION_INDEX"] = atomsForCombination.str();
    replacements["IS_VALID_COMBINATION"] = isValidCombination.str();
    replacements["VERIFY_CUTOFF"] = verifyCutoff.str();
    replacements["VERIFY_EXCLUSIONS"] = verifyExclusions.str();
    replacements["PERMUTE_ATOMS"] = permute.str();
    replacements["LOAD_PARTICLE_DATA"] = loadData.str();
    replacements["COMPUTE_TYPE_INDEX"] = computeTypeIndex;
    replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
    map<string, string> defines;
    if (nonbondedMethod != NoCutoff)
        defines["USE_CUTOFF"] = "1";
    if (nonbondedMethod == CutoffPeriodic)
        defines["USE_PERIODIC"] = "1";
    if (centralParticleMode)
        defines["USE_CENTRAL_PARTICLE"] = "1";
    if (hasTypeFilters)
        defines["USE_FILTERS"] = "1";
    if (force.getNumExclusions() > 0)
        defines["USE_EXCLUSIONS"] = "1";
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
    defines["M_PI"] = cl.doubleToString(M_PI);
    defines["CUTOFF_SQUARED"] = cl.doubleToString(force.getCutoffDistance()*force.getCutoffDistance());
    defines["TILE_SIZE"] = cl.intToString(OpenCLContext::TileSize);
    defines["NUM_BLOCKS"] = cl.intToString(cl.getNumAtomBlocks());
    defines["FIND_NEIGHBORS_WORKGROUP_SIZE"] = cl.intToString(findNeighborsWorkgroupSize);
    cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customManyParticle, replacements), defines);
    forceKernel = cl::Kernel(program, "computeInteraction");
    blockBoundsKernel = cl::Kernel(program, "findBlockBounds");
    neighborsKernel = cl::Kernel(program, "findNeighbors");
    startIndicesKernel = cl::Kernel(program, "computeNeighborStartIndices");
    copyPairsKernel = cl::Kernel(program, "copyPairsToNeighborList");
}

// Run the CustomManyParticleForce kernels for the current state of the context.
//
// On the first call the fixed kernel arguments are bound.  On every call the global
// parameters are refreshed if they changed, the neighbor list is rebuilt when a cutoff is
// in use, and the force kernel is launched.  If the neighbor list turns out to be too
// small it is enlarged and the whole sequence is repeated.
//
// context:       used to read the current values of the global parameters
// includeForces: not consulted by this function
// includeEnergy: not consulted by this function
// returns:       always 0.0; the force kernel is given the context's energy buffer
double OpenCLCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    if (!hasInitializedKernel) {
        hasInitializedKernel = true;

        // Set arguments for the force kernel.  Arguments 3-7 are the periodic box, which
        // puts the neighbor list (when present) at argument 8.

        int index = 0;
        forceKernel.setArg<cl::Buffer>(index++, cl.getLongForceBuffer().getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        forceKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        setPeriodicBoxArgs(cl, forceKernel, index);
        index += 5;
        if (nonbondedMethod != NoCutoff) {
            forceKernel.setArg<cl::Buffer>(index++, neighbors->getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, neighborStartIndex->getDeviceBuffer());
        }
        if (particleTypes != NULL) {
            forceKernel.setArg<cl::Buffer>(index++, particleTypes->getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, orderIndex->getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, particleOrder->getDeviceBuffer());
        }
        if (exclusions != NULL) {
            forceKernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
            forceKernel.setArg<cl::Buffer>(index++, exclusionStartIndex->getDeviceBuffer());
        }
        if (globals != NULL)
            forceKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
        for (int i = 0; i < (int) params->getBuffers().size(); i++) {
            OpenCLNonbondedUtilities::ParameterInfo& buffer = params->getBuffers()[i];
            forceKernel.setArg<cl::Memory>(index++, buffer.getMemory());
        }
        for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
            forceKernel.setArg<cl::Buffer>(index++, tabulatedFunctions[i]->getDeviceBuffer());

        if (nonbondedMethod != NoCutoff) {
            // Set arguments for the block bounds kernel.

            index = 0;
            setPeriodicBoxArgs(cl, blockBoundsKernel, index);
            index += 5;
            blockBoundsKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            blockBoundsKernel.setArg<cl::Buffer>(index++, blockCenter->getDeviceBuffer());
            blockBoundsKernel.setArg<cl::Buffer>(index++, blockBoundingBox->getDeviceBuffer());
            blockBoundsKernel.setArg<cl::Buffer>(index++, numNeighborPairs->getDeviceBuffer());

            // Set arguments for the neighbor list kernel.

            index = 0;
            setPeriodicBoxArgs(cl, neighborsKernel, index);
            index += 5;
            neighborsKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, blockCenter->getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, blockBoundingBox->getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, neighborPairs->getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, numNeighborPairs->getDeviceBuffer());
            neighborsKernel.setArg<cl::Buffer>(index++, numNeighborsForAtom->getDeviceBuffer());
            index++; // Argument 11 is maxNeighborPairs, which is set before every launch.
            if (exclusions != NULL) {
                neighborsKernel.setArg<cl::Buffer>(index++, exclusions->getDeviceBuffer());
                neighborsKernel.setArg<cl::Buffer>(index++, exclusionStartIndex->getDeviceBuffer());
            }

            // Set arguments for the kernel to find neighbor list start indices.

            index = 0;
            startIndicesKernel.setArg<cl::Buffer>(index++, numNeighborsForAtom->getDeviceBuffer());
            startIndicesKernel.setArg<cl::Buffer>(index++, neighborStartIndex->getDeviceBuffer());
            startIndicesKernel.setArg<cl::Buffer>(index++, numNeighborPairs->getDeviceBuffer());

            // Set arguments for the kernel to assemble the final neighbor list.

            index = 0;
            copyPairsKernel.setArg<cl::Buffer>(index++, neighborPairs->getDeviceBuffer());
            copyPairsKernel.setArg<cl::Buffer>(index++, neighbors->getDeviceBuffer());
            copyPairsKernel.setArg<cl::Buffer>(index++, numNeighborPairs->getDeviceBuffer());
            index++; // Argument 3 is maxNeighborPairs, which is set before every launch.
            copyPairsKernel.setArg<cl::Buffer>(index++, numNeighborsForAtom->getDeviceBuffer());
            copyPairsKernel.setArg<cl::Buffer>(index++, neighborStartIndex->getDeviceBuffer());
        }
    }

    // Upload the global parameters again only if at least one value has changed.

    if (globals != NULL) {
        bool changed = false;
        for (int i = 0; i < (int) globalParamNames.size(); i++) {
            cl_float value = (cl_float) context.getParameter(globalParamNames[i]);
            if (value != globalParamValues[i])
                changed = true;
            globalParamValues[i] = value;
        }
        if (changed)
            globals->upload(globalParamValues);
    }

    // Refresh the periodic box arguments on every call.  They were previously bound only
    // during the one-time setup above, so kernels would keep using the box captured on the
    // first call.  The indices match the positions used in that setup.

    setPeriodicBoxArgs(cl, forceKernel, 3);
    if (nonbondedMethod != NoCutoff) {
        setPeriodicBoxArgs(cl, blockBoundsKernel, 0);
        setPeriodicBoxArgs(cl, neighborsKernel, 0);
    }
    while (true) {
        int* numPairs = (int*) cl.getPinnedBuffer();
        cl::Event event;
        if (nonbondedMethod != NoCutoff) {
            neighborsKernel.setArg<int>(11, maxNeighborPairs);
            startIndicesKernel.setArg<int>(3, maxNeighborPairs);
            copyPairsKernel.setArg<int>(3, maxNeighborPairs);
            cl.executeKernel(blockBoundsKernel, cl.getNumAtomBlocks());
            cl.executeKernel(neighborsKernel, cl.getNumAtoms(), findNeighborsWorkgroupSize);

            // We need to make sure there was enough memory for the neighbor list.  Download the
            // information asynchronously so kernels can be running at the same time.

            numNeighborPairs->download(numPairs, false);
            cl.getQueue().enqueueMarker(&event);
            cl.executeKernel(startIndicesKernel, 256, 256);
            cl.executeKernel(copyPairsKernel, maxNeighborPairs);
        }
        int maxThreads = min(cl.getNumAtoms()*forceWorkgroupSize, cl.getEnergyBuffer().getSize());
        cl.executeKernel(forceKernel, maxThreads, forceWorkgroupSize);
        if (nonbondedMethod != NoCutoff) {
            // Make sure there was enough memory for the neighbor list.

            event.wait();
            if (*numPairs > maxNeighborPairs) {
                // Resize the arrays and run the calculation again.  The pointers are cleared
                // before reallocating so the destructor never sees a dangling pointer if
                // an allocation throws.

                delete neighborPairs;
                neighborPairs = NULL;
                delete neighbors;
                neighbors = NULL;
                maxNeighborPairs = (int) (1.1*(*numPairs));
                neighborPairs = OpenCLArray::create<mm_int2>(cl, maxNeighborPairs, "customManyParticleNeighborPairs");
                neighbors = OpenCLArray::create<int>(cl, maxNeighborPairs, "customManyParticleNeighbors");

                // Rebind every kernel argument that referred to the old arrays.

                forceKernel.setArg<cl::Buffer>(8, neighbors->getDeviceBuffer());
                neighborsKernel.setArg<cl::Buffer>(8, neighborPairs->getDeviceBuffer());
                copyPairsKernel.setArg<cl::Buffer>(0, neighborPairs->getDeviceBuffer());
                copyPairsKernel.setArg<cl::Buffer>(1, neighbors->getDeviceBuffer());
                continue;
            }
        }
        break;
    }
    return 0.0;
}

// Copy the per-particle parameters of `force` into the device parameter set.
// Throws OpenMMException if the force's particle count differs from the context's atom count.
void OpenCLCalcCustomManyParticleForceKernel::copyParametersToContext(ContextImpl& context, const CustomManyParticleForce& force) {
    const int particleCount = force.getNumParticles();
    if (particleCount != cl.getNumAtoms())
        throw OpenMMException("updateParametersInContext: The number of particles has changed");

    // Convert each particle's parameters to single precision, one row per particle.

    vector<vector<float> > floatParams(particleCount);
    vector<double> particleParams;
    int particleType;
    for (int particle = 0; particle < particleCount; particle++) {
        force.getParticleParameters(particle, particleParams, particleType);
        vector<float>& row = floatParams[particle];
        row.reserve(particleParams.size());
        for (vector<double>::const_iterator value = particleParams.begin(); value != particleParams.end(); ++value)
            row.push_back((float) *value);
    }
    params->setParameterValues(floatParams);

    // The existing atom reordering may no longer be valid, so flag it.

    cl.invalidateMolecules();
}

5349
5350
5351
5352
// This destructor performs no explicit cleanup.
OpenCLIntegrateVerletStepKernel::~OpenCLIntegrateVerletStepKernel() {
}

void OpenCLIntegrateVerletStepKernel::initialize(const System& system, const VerletIntegrator& integrator) {
5353
    cl.getPlatformData().initializeContexts(system);
5354
    cl::Program program = cl.createProgram(OpenCLKernelSources::verlet, "");
5355
5356
    kernel1 = cl::Kernel(program, "integrateVerletPart1");
    kernel2 = cl::Kernel(program, "integrateVerletPart2");
5357
    prevStepSize = -1.0;
5358
5359
5360
}

// Take one Verlet step: run the part 1 kernel, apply constraints, run the part 2
// kernel, recompute virtual sites, then advance the time and step count.
void OpenCLIntegrateVerletStepKernel::execute(ContextImpl& context, const VerletIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int numAtoms = cl.getNumAtoms();
    const double dt = integrator.getStepSize();

    // Bind the fixed kernel arguments the first time this is called.

    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        kernel1.setArg<cl_int>(0, numAtoms);
        kernel1.setArg<cl::Buffer>(1, integration.getStepSize().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel1, 3);
        kernel1.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(6, integration.getPosDelta().getDeviceBuffer());

        kernel2.setArg<cl_int>(0, numAtoms);
        kernel2.setArg<cl::Buffer>(1, integration.getStepSize().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 3);
        kernel2.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(5, integration.getPosDelta().getDeviceBuffer());
    }

    // Upload the step size only when it differs from the one used last time.

    if (dt != prevStepSize) {
        const bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
        if (useDouble) {
            vector<mm_double2> stepSizeVec(1, mm_double2(dt, dt));
            integration.getStepSize().upload(stepSizeVec);
        }
        else {
            const cl_float dtFloat = (cl_float) dt;
            vector<mm_float2> stepSizeVec(1, mm_float2(dtFloat, dtFloat));
            integration.getStepSize().upload(stepSizeVec);
        }
        prevStepSize = dt;
    }

    // First integration kernel, then constraints, then the second kernel.

    cl.executeKernel(kernel1, numAtoms);
    integration.applyConstraints(integrator.getConstraintTolerance());
    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Advance the simulation clock and step counter.

    cl.setTime(cl.getTime()+dt);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();

    // Flush the queue on Windows to reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif
}

5420
5421
5422
5423
// Return the kinetic energy reported by the integration utilities, passing half
// the integrator's step size as the argument.
double OpenCLIntegrateVerletStepKernel::computeKineticEnergy(ContextImpl& context, const VerletIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

5424
5425
5426
5427
5428
5429
// Release the parameter array created in initialize().
OpenCLIntegrateLangevinStepKernel::~OpenCLIntegrateLangevinStepKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is needed.
    delete params;
}

void OpenCLIntegrateLangevinStepKernel::initialize(const System& system, const LangevinIntegrator& integrator) {
5430
    cl.getPlatformData().initializeContexts(system);
5431
5432
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    map<string, string> defines;
5433
5434
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
5435
    cl::Program program = cl.createProgram(OpenCLKernelSources::langevin, defines, "");
5436
5437
    kernel1 = cl::Kernel(program, "integrateLangevinPart1");
    kernel2 = cl::Kernel(program, "integrateLangevinPart2");
5438
    params = new OpenCLArray(cl, 3, cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(cl_double) : sizeof(cl_float), "langevinParams");
5439
5440
5441
5442
    prevStepSize = -1.0;
}

// Take one Langevin step: refresh the integration parameters if the temperature,
// friction or step size changed, run the part 1 kernel, apply constraints, run the
// part 2 kernel, recompute virtual sites, then advance the time and step count.
void OpenCLIntegrateLangevinStepKernel::execute(ContextImpl& context, const LangevinIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    if (!hasInitializedKernels) {
        // Bind the fixed kernel arguments the first time this is called.

        hasInitializedKernels = true;
        kernel1.setArg<cl::Buffer>(0, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(1, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, integration.getPosDelta().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(3, params->getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, integration.getRandom().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 1);
        kernel2.setArg<cl::Buffer>(2, integration.getPosDelta().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(3, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
    }
    double temperature = integrator.getTemperature();
    double friction = integrator.getFriction();
    double stepSize = integrator.getStepSize();
    if (temperature != prevTemp || friction != prevFriction || stepSize != prevStepSize) {
        // Calculate the integration parameters.  They are written in terms of the
        // friction rather than tau = 1/friction so that friction == 0 is well defined:
        // the previous form evaluated exp(-stepSize/0) and sqrt(2*kT/0)*sqrt(0), which
        // gave vscale = 0 and noisescale = NaN.  For friction > 0 the values are
        // algebraically identical to the tau-based expressions.

        double kT = BOLTZ*temperature;
        double vscale = exp(-stepSize*friction);
        // Limit of (1-vscale)/friction as friction -> 0 is stepSize.
        double fscale = (friction == 0.0 ? stepSize : (1-vscale)/friction);
        // Equals sqrt(2*kT/tau)*sqrt(0.5*(1-vscale*vscale)*tau) with tau cancelled out.
        double noisescale = sqrt(kT*(1-vscale*vscale));
        if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
            vector<cl_double> p(params->getSize());
            p[0] = vscale;
            p[1] = fscale;
            p[2] = noisescale;
            params->upload(p);
            mm_double2 ss = mm_double2(0, stepSize);
            integration.getStepSize().upload(&ss);
        }
        else {
            vector<cl_float> p(params->getSize());
            p[0] = (cl_float) vscale;
            p[1] = (cl_float) fscale;
            p[2] = (cl_float) noisescale;
            params->upload(p);
            mm_float2 ss = mm_float2(0, (float) stepSize);
            integration.getStepSize().upload(&ss);
        }
        prevTemp = temperature;
        prevFriction = friction;
        prevStepSize = stepSize;
    }

    // Call the first integration kernel.

    kernel1.setArg<cl_uint>(6, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel1, numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Update the time and step count.

    cl.setTime(cl.getTime()+stepSize);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    
    // Reduce UI lag.
    
#ifdef WIN32
    cl.getQueue().flush();
#endif
}
5519

5520
5521
5522
5523
// Return the kinetic energy reported by the integration utilities, passing half
// the integrator's step size as the argument.
double OpenCLIntegrateLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const LangevinIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

5524
5525
5526
5527
// This destructor performs no explicit cleanup.
OpenCLIntegrateBrownianStepKernel::~OpenCLIntegrateBrownianStepKernel() {
}

void OpenCLIntegrateBrownianStepKernel::initialize(const System& system, const BrownianIntegrator& integrator) {
5528
    cl.getPlatformData().initializeContexts(system);
5529
5530
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    map<string, string> defines;
5531
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
5532
    cl::Program program = cl.createProgram(OpenCLKernelSources::brownian, defines, "");
5533
5534
5535
5536
5537
5538
5539
5540
5541
5542
5543
5544
    kernel1 = cl::Kernel(program, "integrateBrownianPart1");
    kernel2 = cl::Kernel(program, "integrateBrownianPart2");
    prevStepSize = -1.0;
}

// Take one Brownian step: refresh the scalar kernel arguments if the temperature,
// friction or step size changed, run the part 1 kernel, apply constraints, run the
// part 2 kernel, recompute virtual sites, then advance the time and step count.
void OpenCLIntegrateBrownianStepKernel::execute(ContextImpl& context, const BrownianIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int numAtoms = cl.getNumAtoms();

    // Bind the buffer arguments the first time this is called.

    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        kernel1.setArg<cl::Buffer>(2, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(3, integration.getPosDelta().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, integration.getRandom().getDeviceBuffer());

        kernel2.setArg<cl::Buffer>(1, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 2);
        kernel2.setArg<cl::Buffer>(3, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(4, integration.getPosDelta().getDeviceBuffer());
    }

    const double temperature = integrator.getTemperature();
    const double friction = integrator.getFriction();
    const double stepSize = integrator.getStepSize();
    const bool parametersChanged = (temperature != prevTemp || friction != prevFriction || stepSize != prevStepSize);
    if (parametersChanged) {
        // Compute the three scalar arguments once in double precision, then pass
        // them in the precision the kernels were compiled for.

        const double tau = (friction == 0.0 ? 0.0 : 1.0/friction);
        const double tauTimesStep = tau*stepSize;
        const double noiseAmplitude = sqrt(2.0f*BOLTZ*temperature*stepSize*tau);
        const double inverseStep = 1.0/stepSize;
        if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
            kernel1.setArg<cl_double>(0, tauTimesStep);
            kernel1.setArg<cl_double>(1, noiseAmplitude);
            kernel2.setArg<cl_double>(0, inverseStep);
        }
        else {
            kernel1.setArg<cl_float>(0, (cl_float) tauTimesStep);
            kernel1.setArg<cl_float>(1, (cl_float) noiseAmplitude);
            kernel2.setArg<cl_float>(0, (cl_float) inverseStep);
        }
        prevTemp = temperature;
        prevFriction = friction;
        prevStepSize = stepSize;
    }

    // First integration kernel; its last argument comes from prepareRandomNumbers().

    kernel1.setArg<cl_uint>(6, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel1, numAtoms);

    // Constraints, then the second integration kernel.

    integration.applyConstraints(integrator.getConstraintTolerance());
    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Advance the simulation clock and step counter.

    cl.setTime(cl.getTime()+stepSize);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();

    // Flush the queue on Windows to reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif
}
5598

5599
5600
5601
5602
// Return the kinetic energy reported by the integration utilities, passing 0 as
// the argument.
double OpenCLIntegrateBrownianStepKernel::computeKineticEnergy(ContextImpl& context, const BrownianIntegrator& integrator) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    return integration.computeKineticEnergy(0);
}

5603
5604
5605
5606
// This destructor performs no explicit cleanup.
OpenCLIntegrateVariableVerletStepKernel::~OpenCLIntegrateVariableVerletStepKernel() {
}

void OpenCLIntegrateVariableVerletStepKernel::initialize(const System& system, const VariableVerletIntegrator& integrator) {
5607
    cl.getPlatformData().initializeContexts(system);
5608
    cl::Program program = cl.createProgram(OpenCLKernelSources::verlet, "");
5609
5610
5611
    kernel1 = cl::Kernel(program, "integrateVerletPart1");
    kernel2 = cl::Kernel(program, "integrateVerletPart2");
    selectSizeKernel = cl::Kernel(program, "selectVerletStepSize");
5612
    blockSize = min(min(256, system.getNumParticles()), (int) selectSizeKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(cl.getDevice()));
5613
5614
}

5615
// Take one variable-size Verlet step that does not go past maxTime.  The step size
// is chosen on the device by the selection kernel and downloaded afterwards.
// Returns the step size that was actually taken.
double OpenCLIntegrateVariableVerletStepKernel::execute(ContextImpl& context, const VariableVerletIntegrator& integrator, double maxTime) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int numAtoms = cl.getNumAtoms();
    const bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    // Bind the fixed kernel arguments the first time this is called.

    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        kernel1.setArg<cl_int>(0, numAtoms);
        kernel1.setArg<cl::Buffer>(1, integration.getStepSize().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel1, 3);
        kernel1.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(6, integration.getPosDelta().getDeviceBuffer());

        kernel2.setArg<cl_int>(0, numAtoms);
        kernel2.setArg<cl::Buffer>(1, integration.getStepSize().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(2, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 3);
        kernel2.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(5, integration.getPosDelta().getDeviceBuffer());

        selectSizeKernel.setArg<cl_int>(0, numAtoms);
        selectSizeKernel.setArg<cl::Buffer>(3, integration.getStepSize().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(4, cl.getVelm().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(5, cl.getForce().getDeviceBuffer());

        // Argument 6 is a sized buffer with no host pointer: one element per work item.

        int elementSize = (useDouble ? sizeof(cl_double) : sizeof(cl_float));
        selectSizeKernel.setArg(6, blockSize*elementSize, NULL);
    }

    // Select the step size to use, capped so that the step ends at maxTime at the latest.

    const double maxStepSize = maxTime-cl.getTime();
    const float maxStepSizeFloat = (float) maxStepSize;
    const double errorTolerance = integrator.getErrorTolerance();
    if (useDouble) {
        selectSizeKernel.setArg<cl_double>(1, maxStepSize);
        selectSizeKernel.setArg<cl_double>(2, errorTolerance);
    }
    else {
        selectSizeKernel.setArg<cl_float>(1, maxStepSizeFloat);
        selectSizeKernel.setArg<cl_float>(2, (cl_float) errorTolerance);
    }
    cl.executeKernel(selectSizeKernel, blockSize, blockSize);

    // First integration kernel, constraints, second integration kernel.

    cl.executeKernel(kernel1, numAtoms);
    integration.applyConstraints(integrator.getConstraintTolerance());
    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Flush the queue on Windows to reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif

    // Download the step size that was selected and work out the new time.

    double dt;
    bool tookMaxStep;
    if (useDouble) {
        mm_double2 selected;
        integration.getStepSize().download(&selected);
        dt = selected.y;
        tookMaxStep = (dt == maxStepSize);
    }
    else {
        mm_float2 selected;
        integration.getStepSize().download(&selected);
        dt = selected.y;
        tookMaxStep = (dt == maxStepSizeFloat);
    }
    double time = cl.getTime()+dt;
    if (tookMaxStep)
        time = maxTime; // Avoid round-off error

    cl.setTime(time);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    return dt;
}

5700
5701
5702
5703
// Return the kinetic energy reported by the integration utilities, passing half
// the integrator's step size as the argument.
double OpenCLIntegrateVariableVerletStepKernel::computeKineticEnergy(ContextImpl& context, const VariableVerletIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

5704
5705
5706
5707
5708
5709
// Release the parameter array created in initialize().
OpenCLIntegrateVariableLangevinStepKernel::~OpenCLIntegrateVariableLangevinStepKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is needed.
    delete params;
}

void OpenCLIntegrateVariableLangevinStepKernel::initialize(const System& system, const VariableLangevinIntegrator& integrator) {
5710
    cl.getPlatformData().initializeContexts(system);
5711
5712
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    map<string, string> defines;
5713
5714
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cl.intToString(cl.getPaddedNumAtoms());
5715
    cl::Program program = cl.createProgram(OpenCLKernelSources::langevin, defines, "");
5716
5717
5718
    kernel1 = cl::Kernel(program, "integrateLangevinPart1");
    kernel2 = cl::Kernel(program, "integrateLangevinPart2");
    selectSizeKernel = cl::Kernel(program, "selectLangevinStepSize");
5719
    params = new OpenCLArray(cl, 3, cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(cl_double) : sizeof(cl_float), "langevinParams");
Peter Eastman's avatar
Peter Eastman committed
5720
5721
    blockSize = min(256, system.getNumParticles());
    blockSize = max(blockSize, params->getSize());
5722
    blockSize = min(blockSize, (int) selectSizeKernel.getWorkGroupInfo<CL_KERNEL_WORK_GROUP_SIZE>(cl.getDevice()));
5723
5724
}

5725
// Take one variable-size Langevin step that does not go past maxTime.  The step
// size is chosen on the device by the selection kernel and downloaded afterwards.
// Returns the step size that was actually taken.
double OpenCLIntegrateVariableLangevinStepKernel::execute(ContextImpl& context, const VariableLangevinIntegrator& integrator, double maxTime) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    const int numAtoms = cl.getNumAtoms();
    const bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();

    // Bind the fixed kernel arguments the first time this is called.

    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        kernel1.setArg<cl::Buffer>(0, cl.getVelm().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(1, cl.getForce().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(2, integration.getPosDelta().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(3, params->getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
        kernel1.setArg<cl::Buffer>(5, integration.getRandom().getDeviceBuffer());

        kernel2.setArg<cl::Buffer>(0, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kernel2, 1);
        kernel2.setArg<cl::Buffer>(2, integration.getPosDelta().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(3, cl.getVelm().getDeviceBuffer());
        kernel2.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());

        selectSizeKernel.setArg<cl::Buffer>(4, integration.getStepSize().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(5, cl.getVelm().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(6, cl.getForce().getDeviceBuffer());
        selectSizeKernel.setArg<cl::Buffer>(7, params->getDeviceBuffer());

        // Arguments 8 and 9 are sized buffers with no host pointer: one sized for
        // the parameters, one with an element per work item.

        int elementSize = (useDouble ? sizeof(cl_double) : sizeof(cl_float));
        selectSizeKernel.setArg(8, params->getSize()*elementSize, NULL);
        selectSizeKernel.setArg(9, blockSize*elementSize, NULL);
    }

    // Select the step size to use, capped so that the step ends at maxTime at the latest.

    const double maxStepSize = maxTime-cl.getTime();
    const float maxStepSizeFloat = (float) maxStepSize;
    const double errorTolerance = integrator.getErrorTolerance();
    const double friction = integrator.getFriction();
    const double tau = (friction == 0.0 ? 0.0 : 1.0/friction);
    const double kT = BOLTZ*integrator.getTemperature();
    if (useDouble) {
        selectSizeKernel.setArg<cl_double>(0, maxStepSize);
        selectSizeKernel.setArg<cl_double>(1, errorTolerance);
        selectSizeKernel.setArg<cl_double>(2, tau);
        selectSizeKernel.setArg<cl_double>(3, kT);
    }
    else {
        selectSizeKernel.setArg<cl_float>(0, maxStepSizeFloat);
        selectSizeKernel.setArg<cl_float>(1, (cl_float) errorTolerance);
        selectSizeKernel.setArg<cl_float>(2, (cl_float) tau);
        selectSizeKernel.setArg<cl_float>(3, (cl_float) kT);
    }
    cl.executeKernel(selectSizeKernel, blockSize, blockSize);

    // First integration kernel; its last argument comes from prepareRandomNumbers().

    kernel1.setArg<cl_uint>(6, integration.prepareRandomNumbers(cl.getPaddedNumAtoms()));
    cl.executeKernel(kernel1, numAtoms);

    // Constraints, then the second integration kernel.

    integration.applyConstraints(integrator.getConstraintTolerance());
    cl.executeKernel(kernel2, numAtoms);
    integration.computeVirtualSites();

    // Flush the queue on Windows to reduce UI lag.

#ifdef WIN32
    cl.getQueue().flush();
#endif

    // Download the step size that was selected and work out the new time.

    double dt;
    bool tookMaxStep;
    if (useDouble) {
        mm_double2 selected;
        integration.getStepSize().download(&selected);
        dt = selected.y;
        tookMaxStep = (dt == maxStepSize);
    }
    else {
        mm_float2 selected;
        integration.getStepSize().download(&selected);
        dt = selected.y;
        tookMaxStep = (dt == maxStepSizeFloat);
    }
    double time = cl.getTime()+dt;
    if (tookMaxStep)
        time = maxTime; // Avoid round-off error

    cl.setTime(time);
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    return dt;
}

5814
5815
5816
5817
// Return the kinetic energy reported by the integration utilities, passing half
// the integrator's step size as the argument.
double OpenCLIntegrateVariableLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const VariableLangevinIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    return cl.getIntegrationUtilities().computeKineticEnergy(halfStep);
}

5818
5819
// Listener that permutes the per-DOF variable values whenever the context
// reorders its atoms, so each value stays attached to the same atom.
class OpenCLIntegrateCustomStepKernel::ReorderListener : public OpenCLContext::ReorderListener {
public:
    // Store references to the kernel's state and remember the current atom order.
    ReorderListener(OpenCLContext& cl, OpenCLParameterSet& perDofValues, vector<vector<cl_float> >& localPerDofValuesFloat, vector<vector<cl_double> >& localPerDofValuesDouble, bool& deviceValuesAreCurrent) :
            cl(cl), perDofValues(perDofValues), localPerDofValuesFloat(localPerDofValuesFloat), localPerDofValuesDouble(localPerDofValuesDouble), deviceValuesAreCurrent(deviceValuesAreCurrent) {
        const vector<int>& initialOrder = cl.getAtomIndex();
        lastAtomOrder.assign(initialOrder.begin(), initialOrder.begin()+cl.getNumAtoms());
    }
    // Reorder the per-DOF variables to reflect the new atom order.
    void execute() {
        if (perDofValues.getNumParameters() == 0)
            return;
        const vector<int>& order = cl.getAtomIndex();
        if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision())
            reorderValues(localPerDofValuesDouble, order);
        else
            reorderValues(localPerDofValuesFloat, order);

        // Remember the order just applied and mark the device copy as current.

        lastAtomOrder.assign(order.begin(), order.begin()+cl.getNumAtoms());
        deviceValuesAreCurrent = true;
    }
private:
    // Permute `localValues` (three entries per atom) from the order recorded in
    // lastAtomOrder to `order`, then upload the result to the device.
    template <class Real>
    void reorderValues(vector<vector<Real> >& localValues, const vector<int>& order) {
        // Refresh the host copy first if the device holds the current values.

        if (deviceValuesAreCurrent)
            perDofValues.getParameterValues(localValues);
        const int numAtoms = cl.getNumAtoms();

        // Scatter the values to slots indexed through the previous order...

        vector<vector<Real> > byOriginalIndex(3*numAtoms);
        for (int atom = 0; atom < numAtoms; atom++)
            for (int axis = 0; axis < 3; axis++)
                byOriginalIndex[3*lastAtomOrder[atom]+axis] = localValues[3*atom+axis];

        // ...then gather them back through the new order.

        for (int atom = 0; atom < numAtoms; atom++)
            for (int axis = 0; axis < 3; axis++)
                localValues[3*atom+axis] = byOriginalIndex[3*order[atom]+axis];
        perDofValues.setParameterValues(localValues);
    }
    OpenCLContext& cl;
    OpenCLParameterSet& perDofValues;
    vector<vector<cl_float> >& localPerDofValuesFloat;
    vector<vector<cl_double> >& localPerDofValuesDouble;
    bool& deviceValuesAreCurrent;
    vector<int> lastAtomOrder;
};

5879
5880
5881
// Release every device array owned by this kernel, including the saved force arrays.
OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() {
    // Deleting a null pointer is a no-op, so members that were never allocated
    // need no explicit check.

    delete globalValues;
    delete contextParameterValues;
    delete sumBuffer;
    delete potentialEnergy;
    delete kineticEnergy;
    delete uniformRandoms;
    delete randomSeed;
    delete perDofValues;
    for (map<int, OpenCLArray*>::iterator saved = savedForces.begin(); saved != savedForces.end(); ++saved)
        delete saved->second;
}

void OpenCLIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) {
    cl.getPlatformData().initializeContexts(system);
    cl.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed());
    numGlobalVariables = integrator.getNumGlobalVariables();
5904
5905
    int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
    globalValues = new OpenCLArray(cl, max(1, numGlobalVariables), elementSize, "globalVariables");
5906
    sumBuffer = new OpenCLArray(cl, ((3*system.getNumParticles()+3)/4)*4, elementSize, "sumBuffer");
5907
    potentialEnergy = new OpenCLArray(cl, 1, cl.getEnergyBuffer().getElementSize(), "potentialEnergy");
5908
    kineticEnergy = new OpenCLArray(cl, 1, elementSize, "kineticEnergy");
5909
5910
    perDofValues = new OpenCLParameterSet(cl, integrator.getNumPerDofVariables(), 3*system.getNumParticles(), "perDofVariables", false, cl.getUseDoublePrecision() || cl.getUseMixedPrecision());
    cl.addReorderListener(new ReorderListener(cl, *perDofValues, localPerDofValuesFloat, localPerDofValuesDouble, deviceValuesAreCurrent));
5911
5912
5913
5914
    prevStepSize = -1.0;
    SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed());
}

5915
// Generate kernel source that evaluates `expr` and assigns the result to the
// global quantity called `variable`: "dt", a global variable of the integrator, or
// a context parameter.  Throws OpenMMException if `variable` matches none of them.
string OpenCLIntegrateCustomStepKernel::createGlobalComputation(const string& variable, const Lepton::ParsedExpression& expr, CustomIntegrator& integrator, const string& energyName) {
    const bool assignsStepSize = (variable == "dt");

    // `expressions` maps an assignment prefix to the expression it stores;
    // `variables` maps each name usable inside the expression to its source text.

    map<string, Lepton::ParsedExpression> expressions;
    map<string, string> variables;
    if (assignsStepSize)
        expressions["dt[0].y = "] = expr;
    variables["dt"] = "dt[0].y";
    variables["uniform"] = "uniform";
    variables["gaussian"] = "gaussian";
    variables[energyName] = "energy[0]";

    // Global variables: always readable, and the assignment target when the name
    // matches (only considered when the target is not "dt").

    for (int g = 0; g < integrator.getNumGlobalVariables(); g++) {
        const string& globalName = integrator.getGlobalVariableName(g);
        const string globalRef = "globals["+cl.intToString(g)+"]";
        if (!assignsStepSize && variable == globalName)
            expressions[globalRef+" = "] = expr;
        variables[globalName] = globalRef;
    }

    // Context parameters, handled the same way; writing to one is recorded in
    // modifiesParameters.

    for (int p = 0; p < (int) parameterNames.size(); p++) {
        const string paramRef = "params["+cl.intToString(p)+"]";
        if (!assignsStepSize && variable == parameterNames[p]) {
            expressions[paramRef+" = "] = expr;
            modifiesParameters = true;
        }
        variables[parameterNames[p]] = paramRef;
    }
    if (expressions.empty())
        throw OpenMMException("Unknown global variable: "+variable);

    // No tabulated functions are passed to the expression generator.

    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
    return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp");
}

5945
// Generate kernel source that evaluates `expr` for one vector component (0, 1 or 2
// selecting .x, .y or .z) and assigns it to the per-DOF quantity called `variable`:
// "x", "v", the sum buffer (empty name) or a per-DOF variable of the integrator.
// Throws OpenMMException if `variable` matches none of them.
string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const string& forceName, const string& energyName) {
    const string suffixes[] = {".x", ".y", ".z"};
    const string suffix = suffixes[component];
    // The per-DOF variable names use the component letter without the leading dot.
    const string perDofPrefix = "perDof"+suffix.substr(1);
    const string componentIndex = cl.intToString(component);

    // Work out the assignment target.

    map<string, Lepton::ParsedExpression> expressions;
    if (variable == "x") {
        expressions["position"+suffix+" = "] = expr;
    }
    else if (variable == "v") {
        expressions["velocity"+suffix+" = "] = expr;
    }
    else if (variable == "") {
        expressions["sum[3*index+"+componentIndex+"] = "] = expr;
    }
    else {
        for (int d = 0; d < integrator.getNumPerDofVariables(); d++) {
            if (variable == integrator.getPerDofVariableName(d))
                expressions[perDofPrefix+perDofValues->getParameterSuffix(d)+" = "] = expr;
        }
    }
    if (expressions.empty())
        throw OpenMMException("Unknown per-DOF variable: "+variable);

    // Map every name the expression may reference to its source text.  Later
    // entries overwrite earlier ones with the same name, so the order is kept.

    map<string, string> variables;
    variables["x"] = "position"+suffix;
    variables["v"] = "velocity"+suffix;
    variables[forceName] = "f"+suffix;
    variables["gaussian"] = "gaussian"+suffix;
    variables["uniform"] = "uniform"+suffix;
    variables["m"] = "mass";
    variables["dt"] = "stepSize";
    if (energyName != "")
        variables[energyName] = "energy[0]";
    for (int g = 0; g < integrator.getNumGlobalVariables(); g++)
        variables[integrator.getGlobalVariableName(g)] = "globals["+cl.intToString(g)+"]";
    for (int d = 0; d < integrator.getNumPerDofVariables(); d++)
        variables[integrator.getPerDofVariableName(d)] = perDofPrefix+perDofValues->getParameterSuffix(d);
    for (int p = 0; p < (int) parameterNames.size(); p++)
        variables[parameterNames[p]] = "params["+cl.intToString(p)+"]";

    // No tabulated functions are passed; temporaries are double when the device
    // supports double precision, float otherwise.

    vector<const TabulatedFunction*> functions;
    vector<pair<string, string> > functionNames;
    const string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
    return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+componentIndex+"_", tempType);
}

5984
/**
 * Prepare everything needed to run the custom integrator on the device.
 *
 * On the first call this analyzes every computation step of the integrator, generates and
 * compiles the OpenCL kernels for them, and creates the kernels used for random number
 * generation and for summing potential and kinetic energy.  On every call it then makes sure
 * the device copies of the per-DOF variables, the step size, and the context parameters match
 * the host side values.
 *
 * @param context         the context being integrated; supplies parameters and force implementations
 * @param integrator      the CustomIntegrator whose computation steps are being compiled
 * @param forcesAreValid  accepted for symmetry with execute(); it is neither read nor modified here
 * @throws OpenMMException if a step depends on more than one force group, or if a ComputeSum
 *                         step targets a name that is neither a global variable nor a context parameter
 */
void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    int numSteps = integrator.getNumComputations();
    bool useDouble = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        
        // Initialize various data structures.  The context parameters are mirrored into a device
        // array (never zero length) and their names are recorded in the same order.
        
        const map<string, double>& params = context.getParameters();
        if (useDouble) {
            contextParameterValues = OpenCLArray::create<cl_double>(cl, max(1, (int) params.size()), "contextParameters");
            contextValuesDouble.resize(contextParameterValues->getSize());
            for (map<string, double>::const_iterator iter = params.begin(); iter != params.end(); ++iter) {
                contextValuesDouble[parameterNames.size()] = iter->second;
                parameterNames.push_back(iter->first);
            }
            contextParameterValues->upload(contextValuesDouble);
        }
        else {
            contextParameterValues = OpenCLArray::create<cl_float>(cl, max(1, (int) params.size()), "contextParameters");
            contextValuesFloat.resize(contextParameterValues->getSize());
            for (map<string, double>::const_iterator iter = params.begin(); iter != params.end(); ++iter) {
                contextValuesFloat[parameterNames.size()] = (float) iter->second;
                parameterNames.push_back(iter->first);
            }
            contextParameterValues->upload(contextValuesFloat);
        }
        kernels.resize(numSteps);
        requiredGaussian.resize(numSteps, 0);
        requiredUniform.resize(numSteps, 0);
        needsForces.resize(numSteps, false);
        needsEnergy.resize(numSteps, false);
        forceGroup.resize(numSteps, -2); // -2 marks a step that has not been assigned a force group.
        invalidatesForces.resize(numSteps, false);
        merged.resize(numSteps, false);
        modifiesParameters = false;
        map<string, string> defines;
        defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
        defines["WORK_GROUP_SIZE"] = cl.intToString(OpenCLContext::ThreadBlockSize);
        
        // Build a list of all variables that affect the forces, so we can tell which
        // steps invalidate them.
        
        set<string> affectsForce;
        affectsForce.insert("x");
        for (vector<ForceImpl*>::const_iterator iter = context.getForceImpls().begin(); iter != context.getForceImpls().end(); ++iter) {
            const map<string, double>& forceParams = (*iter)->getDefaultParameters();
            for (map<string, double>::const_iterator param = forceParams.begin(); param != forceParams.end(); ++param)
                affectsForce.insert(param->first);
        }
        
        // Record information about all the computation steps.
        
        stepType.resize(numSteps);
        vector<string> variable(numSteps);
        vector<Lepton::ParsedExpression> expression(numSteps);
        vector<string> forceGroupName;
        vector<string> energyGroupName;
        for (int i = 0; i < 32; i++) {
            stringstream fname;
            fname << "f" << i;
            forceGroupName.push_back(fname.str());
            stringstream ename;
            ename << "energy" << i;
            energyGroupName.push_back(ename.str());
        }
        vector<string> forceName(numSteps, "f");
        vector<string> energyName(numSteps, "energy");
        for (int step = 0; step < numSteps; step++) {
            string expr;
            integrator.getComputationStep(step, stepType[step], variable[step], expr);
            if (expr.size() > 0) {
                expression[step] = Lepton::Parser::parse(expr).optimize();
                
                // A reference to plain "f" or "energy" selects force group value -1.
                
                if (usesVariable(expression[step], "f")) {
                    needsForces[step] = true;
                    forceGroup[step] = -1;
                }
                if (usesVariable(expression[step], "energy")) {
                    needsEnergy[step] = true;
                    forceGroup[step] = -1;
                }
                
                // A reference to "f<i>" or "energy<i>" selects the single group bit 1<<i.
                
                for (int i = 0; i < 32; i++) {
                    if (usesVariable(expression[step], forceGroupName[i])) {
                        if (forceGroup[step] != -2)
                            throw OpenMMException("A single computation step cannot depend on multiple force groups");
                        needsForces[step] = true;
                        forceGroup[step] = 1<<i;
                        forceName[step] = forceGroupName[i];
                    }
                    if (usesVariable(expression[step], energyGroupName[i])) {
                        if (forceGroup[step] != -2)
                            throw OpenMMException("A single computation step cannot depend on multiple force groups");
                        needsEnergy[step] = true;
                        forceGroup[step] = 1<<i;
                        energyName[step] = energyGroupName[i];
                    }
                }
            }
            invalidatesForces[step] = (stepType[step] == CustomIntegrator::ConstrainPositions || affectsForce.find(variable[step]) != affectsForce.end());
            
            // Steps that do not reference forces or energy inherit the group of the previous step.
            
            if (forceGroup[step] == -2 && step > 0)
                forceGroup[step] = forceGroup[step-1];
            
            // Allocate one saved-forces buffer for each distinct force group that is used.
            
            if (forceGroup[step] != -2 && savedForces.find(forceGroup[step]) == savedForces.end())
                savedForces[forceGroup[step]] = new OpenCLArray(cl, cl.getForce().getSize(), cl.getForce().getElementSize(), "savedForces");
        }
        
        // Determine how each step will represent the position (as just a value, or a value plus a delta).
        
        vector<bool> storePosAsDelta(numSteps, false);
        vector<bool> loadPosAsDelta(numSteps, false);
        bool beforeConstrain = false;
        for (int step = numSteps-1; step >= 0; step--) {
            if (stepType[step] == CustomIntegrator::ConstrainPositions)
                beforeConstrain = true;
            else if (stepType[step] == CustomIntegrator::ComputePerDof && variable[step] == "x" && beforeConstrain)
                storePosAsDelta[step] = true;
        }
        bool storedAsDelta = false;
        for (int step = 0; step < numSteps; step++) {
            loadPosAsDelta[step] = storedAsDelta;
            if (storePosAsDelta[step] == true)
                storedAsDelta = true;
            if (stepType[step] == CustomIntegrator::ConstrainPositions)
                storedAsDelta = false;
        }
        
        // Identify steps that can be merged into a single kernel.
        
        for (int step = 1; step < numSteps; step++) {
            if (needsForces[step] || needsEnergy[step])
                continue;
            if (stepType[step-1] == CustomIntegrator::ComputeGlobal && stepType[step] == CustomIntegrator::ComputeGlobal &&
                    !usesVariable(expression[step], "uniform") && !usesVariable(expression[step], "gaussian"))
                merged[step] = true;
            if (stepType[step-1] == CustomIntegrator::ComputePerDof && stepType[step] == CustomIntegrator::ComputePerDof)
                merged[step] = true;
        }
        
        // Loop over all steps and create the kernels for them.
        
        for (int step = 0; step < numSteps; step++) {
            if ((stepType[step] == CustomIntegrator::ComputePerDof || stepType[step] == CustomIntegrator::ComputeSum) && !merged[step]) {
                // Compute a per-DOF value.
                
                stringstream compute;
                for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) {
                    const OpenCLNonbondedUtilities::ParameterInfo& buffer = perDofValues->getBuffers()[i];
                    compute << buffer.getType()<<" perDofx"<<cl.intToString(i+1)<<" = perDofValues"<<cl.intToString(i+1)<<"[3*index];\n";
                    compute << buffer.getType()<<" perDofy"<<cl.intToString(i+1)<<" = perDofValues"<<cl.intToString(i+1)<<"[3*index+1];\n";
                    compute << buffer.getType()<<" perDofz"<<cl.intToString(i+1)<<" = perDofValues"<<cl.intToString(i+1)<<"[3*index+2];\n";
                }
                int numGaussian = 0, numUniform = 0;
                for (int j = step; j < numSteps && (j == step || merged[j]); j++) {
                    // Only a step that actually references a random variable loads a value and
                    // advances the corresponding index.  Keying this on the running totals would
                    // make a later merged step that does not use random numbers consume a block
                    // anyway, so the kernel would read past the numGaussian/numUniform values
                    // that are recorded below as required for this kernel.
                    
                    const bool stepUsesGaussian = usesVariable(expression[j], "gaussian");
                    const bool stepUsesUniform = usesVariable(expression[j], "uniform");
                    if (stepUsesGaussian)
                        numGaussian += numAtoms;
                    if (stepUsesUniform)
                        numUniform += numAtoms;
                    compute << "{\n";
                    if (stepUsesGaussian)
                        compute << "float4 gaussian = gaussianValues[gaussianIndex+index];\n";
                    if (stepUsesUniform)
                        compute << "float4 uniform = uniformValues[uniformIndex+index];\n";
                    for (int i = 0; i < 3; i++)
                        compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? variable[j] : "", expression[j], i, integrator, forceName[j], energyName[j]);
                    if (variable[j] == "x") {
                        if (storePosAsDelta[j]) {
                            if (cl.getSupportsDoublePrecision())
                                compute << "posDelta[index] = convert_mixed4(convert_double4(position)-convert_double4(loadPos(posq, posqCorrection, index)));\n";
                            else
                                compute << "posDelta[index] = position-posq[index];\n";
                        }
                        else
                            compute << "storePos(posq, posqCorrection, index, position);\n";
                    }
                    else if (variable[j] == "v")
                        compute << "velm[index] = convert_mixed4(velocity);\n";
                    else {
                        for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) {
                            compute << "perDofValues"<<cl.intToString(i+1)<<"[3*index] = perDofx"<<cl.intToString(i+1)<<";\n";
                            compute << "perDofValues"<<cl.intToString(i+1)<<"[3*index+1] = perDofy"<<cl.intToString(i+1)<<";\n";
                            compute << "perDofValues"<<cl.intToString(i+1)<<"[3*index+2] = perDofz"<<cl.intToString(i+1)<<";\n";
                        }
                    }
                    if (stepUsesGaussian)
                        compute << "gaussianIndex += NUM_ATOMS;\n";
                    if (stepUsesUniform)
                        compute << "uniformIndex += NUM_ATOMS;\n";
                    compute << "}\n";
                }
                map<string, string> replacements;
                replacements["COMPUTE_STEP"] = compute.str();
                stringstream args;
                for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) {
                    const OpenCLNonbondedUtilities::ParameterInfo& buffer = perDofValues->getBuffers()[i];
                    string valueName = "perDofValues"+cl.intToString(i+1);
                    args << ", __global " << buffer.getType() << "* restrict " << valueName;
                }
                replacements["PARAMETER_ARGUMENTS"] = args.str();
                
                // The define is shared between steps, so set or clear it for each kernel.
                
                if (loadPosAsDelta[step])
                    defines["LOAD_POS_AS_DELTA"] = "1";
                else
                    defines.erase("LOAD_POS_AS_DELTA");
                cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customIntegratorPerDof, replacements), defines);
                cl::Kernel kernel = cl::Kernel(program, "computePerDof");
                kernels[step].push_back(kernel);
                requiredGaussian[step] = numGaussian;
                requiredUniform[step] = numUniform;
                int index = 0;
                kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
                setPosqCorrectionArg(cl, kernel, index++);
                kernel.setArg<cl::Buffer>(index++, integration.getPosDelta().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, cl.getVelm().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, cl.getForce().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, integration.getStepSize().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, globalValues->getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, contextParameterValues->getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
                index += 3; // Arguments 9, 10 and 11 (random number buffers and index) are set in execute().
                kernel.setArg<cl::Buffer>(index++, potentialEnergy->getDeviceBuffer());
                for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
                    kernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory());
                if (stepType[step] == CustomIntegrator::ComputeSum) {
                    // Create a second kernel for this step that sums the values.

                    program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
                    kernel = cl::Kernel(program, useDouble ? "computeDoubleSum" : "computeFloatSum");
                    kernels[step].push_back(kernel);
                    index = 0;
                    kernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
                    
                    // The sum is written either to a global variable or to a context parameter.
                    
                    bool found = false;
                    for (int j = 0; j < integrator.getNumGlobalVariables() && !found; j++)
                        if (variable[step] == integrator.getGlobalVariableName(j)) {
                            kernel.setArg<cl::Buffer>(index++, globalValues->getDeviceBuffer());
                            kernel.setArg<cl_uint>(index++, j);
                            found = true;
                        }
                    for (int j = 0; j < (int) parameterNames.size() && !found; j++)
                        if (variable[step] == parameterNames[j]) {
                            kernel.setArg<cl::Buffer>(index++, contextParameterValues->getDeviceBuffer());
                            kernel.setArg<cl_uint>(index++, j);
                            found = true;
                            modifiesParameters = true;
                        }
                    if (!found)
                        throw OpenMMException("Unknown global variable: "+variable[step]);
                    kernel.setArg<cl_int>(index++, 3*numAtoms);
                }
            }
            else if (stepType[step] == CustomIntegrator::ComputeGlobal && !merged[step]) {
                // Compute a global value.

                stringstream compute;
                for (int i = step; i < numSteps && (i == step || merged[i]); i++)
                    compute << "{\n" << createGlobalComputation(variable[i], expression[i], integrator, energyName[i]) << "}\n";
                map<string, string> replacements;
                replacements["COMPUTE_STEP"] = compute.str();
                cl::Program program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customIntegratorGlobal, replacements), defines);
                cl::Kernel kernel = cl::Kernel(program, "computeGlobal");
                kernels[step].push_back(kernel);
                int index = 0;
                kernel.setArg<cl::Buffer>(index++, integration.getStepSize().getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, globalValues->getDeviceBuffer());
                kernel.setArg<cl::Buffer>(index++, contextParameterValues->getDeviceBuffer());
                index += 2; // Arguments 3 and 4 (the random values) are set in execute().
                kernel.setArg<cl::Buffer>(index++, potentialEnergy->getDeviceBuffer());
            }
            else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
                // Apply position constraints.

                cl::Program program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
                cl::Kernel kernel = cl::Kernel(program, "applyPositionDeltas");
                kernels[step].push_back(kernel);
                int index = 0;
                kernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
                setPosqCorrectionArg(cl, kernel, index++);
                kernel.setArg<cl::Buffer>(index++, integration.getPosDelta().getDeviceBuffer());
            }
        }
        
        // Initialize the random number generator.
        
        int maxUniformRandoms = 1;
        for (int i = 0; i < (int) requiredUniform.size(); i++)
            maxUniformRandoms = max(maxUniformRandoms, requiredUniform[i]);
        uniformRandoms = OpenCLArray::create<mm_float4>(cl, maxUniformRandoms, "uniformRandoms");
        randomSeed = OpenCLArray::create<mm_int4>(cl, cl.getNumThreadBlocks()*OpenCLContext::ThreadBlockSize, "randomSeed");
        vector<mm_int4> seed(randomSeed->getSize());
        int rseed = integrator.getRandomNumberSeed();
        // A random seed of 0 means use a unique one
        if (rseed == 0)
            rseed = osrngseed();
        unsigned int r = (unsigned int) (rseed+1);
        for (int i = 0; i < randomSeed->getSize(); i++) {
            seed[i].x = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
            seed[i].y = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
            seed[i].z = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
            seed[i].w = r = (1664525*r + 1013904223) & 0xFFFFFFFF;
        }
        randomSeed->upload(seed);
        cl::Program randomProgram = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
        randomKernel = cl::Kernel(randomProgram, "generateRandomNumbers");
        randomKernel.setArg<cl_int>(0, maxUniformRandoms);
        randomKernel.setArg<cl::Buffer>(1, uniformRandoms->getDeviceBuffer());
        randomKernel.setArg<cl::Buffer>(2, randomSeed->getDeviceBuffer());
        
        // Create the kernel for summing the potential energy.

        cl::Program program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
        sumPotentialEnergyKernel = cl::Kernel(program, cl.getUseDoublePrecision() ? "computeDoubleSum" : "computeFloatSum");
        int index = 0;
        sumPotentialEnergyKernel.setArg<cl::Buffer>(index++, cl.getEnergyBuffer().getDeviceBuffer());
        sumPotentialEnergyKernel.setArg<cl::Buffer>(index++, potentialEnergy->getDeviceBuffer());
        sumPotentialEnergyKernel.setArg<cl_int>(index++, 0);
        sumPotentialEnergyKernel.setArg<cl_int>(index++, cl.getEnergyBuffer().getSize());
        
        // Create the kernel for computing kinetic energy.

        stringstream computeKE;
        for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = perDofValues->getBuffers()[i];
            computeKE << buffer.getType()<<" perDofx"<<cl.intToString(i+1)<<" = perDofValues"<<cl.intToString(i+1)<<"[3*index];\n";
            computeKE << buffer.getType()<<" perDofy"<<cl.intToString(i+1)<<" = perDofValues"<<cl.intToString(i+1)<<"[3*index+1];\n";
            computeKE << buffer.getType()<<" perDofz"<<cl.intToString(i+1)<<" = perDofValues"<<cl.intToString(i+1)<<"[3*index+2];\n";
        }
        Lepton::ParsedExpression keExpression = Lepton::Parser::parse(integrator.getKineticEnergyExpression()).optimize();
        for (int i = 0; i < 3; i++)
            computeKE << createPerDofComputation("", keExpression, i, integrator, "f", "");
        map<string, string> replacements;
        replacements["COMPUTE_STEP"] = computeKE.str();
        stringstream args;
        for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) {
            const OpenCLNonbondedUtilities::ParameterInfo& buffer = perDofValues->getBuffers()[i];
            string valueName = "perDofValues"+cl.intToString(i+1);
            args << ", __global " << buffer.getType() << "* restrict " << valueName;
        }
        replacements["PARAMETER_ARGUMENTS"] = args.str();
        defines.erase("LOAD_POS_AS_DELTA"); // May have been left set by the last per-DOF kernel.
        program = cl.createProgram(cl.replaceStrings(OpenCLKernelSources::customIntegratorPerDof, replacements), defines);
        kineticEnergyKernel = cl::Kernel(program, "computePerDof");
        index = 0;
        kineticEnergyKernel.setArg<cl::Buffer>(index++, cl.getPosq().getDeviceBuffer());
        setPosqCorrectionArg(cl, kineticEnergyKernel, index++);
        kineticEnergyKernel.setArg<cl::Buffer>(index++, integration.getPosDelta().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, cl.getVelm().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, cl.getForce().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, integration.getStepSize().getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, globalValues->getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, contextParameterValues->getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
        index += 2; // Arguments 9 and 10 are set in computeKineticEnergy().
        kineticEnergyKernel.setArg<cl::Buffer>(index++, uniformRandoms->getDeviceBuffer());
        kineticEnergyKernel.setArg<cl::Buffer>(index++, potentialEnergy->getDeviceBuffer());
        for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
            kineticEnergyKernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory());
        keNeedsForce = usesVariable(keExpression, "f");

        // Create a second kernel to sum the values.

        program = cl.createProgram(OpenCLKernelSources::customIntegrator, defines);
        sumKineticEnergyKernel = cl::Kernel(program, useDouble ? "computeDoubleSum" : "computeFloatSum");
        index = 0;
        sumKineticEnergyKernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
        sumKineticEnergyKernel.setArg<cl::Buffer>(index++, kineticEnergy->getDeviceBuffer());
        sumKineticEnergyKernel.setArg<cl_int>(index++, 0);
        sumKineticEnergyKernel.setArg<cl_int>(index++, 3*numAtoms);
    }
    
    // Make sure all values (variables, parameters, etc.) stored on the device are up to date.
    
    if (!deviceValuesAreCurrent) {
        if (useDouble)
            perDofValues->setParameterValues(localPerDofValuesDouble);
        else
            perDofValues->setParameterValues(localPerDofValuesFloat);
        deviceValuesAreCurrent = true;
    }
    localValuesAreCurrent = false;
    
    // Upload the step size only when it differs from the value uploaded last time.
    
    double stepSize = integrator.getStepSize();
    if (stepSize != prevStepSize) {
        if (useDouble) {
            mm_double2 ss = mm_double2(0, stepSize);
            integration.getStepSize().upload(&ss);
        }
        else {
            mm_float2 ss = mm_float2(0, (float) stepSize);
            integration.getStepSize().upload(&ss);
        }
        prevStepSize = stepSize;
    }
    
    // Upload the context parameters only if at least one differs from the cached host copy.
    
    bool paramsChanged = false;
    if (useDouble) {
        for (int i = 0; i < (int) parameterNames.size(); i++) {
            double value = context.getParameter(parameterNames[i]);
            if (value != contextValuesDouble[i]) {
                contextValuesDouble[i] = value;
                paramsChanged = true;
            }
        }
        if (paramsChanged)
            contextParameterValues->upload(contextValuesDouble);
    }
    else {
        for (int i = 0; i < (int) parameterNames.size(); i++) {
            float value = (float) context.getParameter(parameterNames[i]);
            if (value != contextValuesFloat[i]) {
                contextValuesFloat[i] = value;
                paramsChanged = true;
            }
        }
        if (paramsChanged)
            contextParameterValues->upload(contextValuesFloat);
    }
}
6398

6399
6400
6401
6402
6403
6404
/**
 * Advance the simulation by one time step by running every computation step of the
 * custom integrator in order.
 *
 * Forces and energy are recomputed only when a step needs them and the current ones are
 * either invalid or belong to a different force group.  Forces for a previously used group
 * are restored from a saved copy when one is still valid and no energy is required.
 *
 * @param context         the context being integrated
 * @param integrator      the CustomIntegrator being executed
 * @param forcesAreValid  on entry, whether the current forces are valid; updated as steps
 *                        compute or invalidate them, and cleared if atoms were reordered
 */
void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    prepareForComputation(context, integrator, forcesAreValid);
    OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
    int numAtoms = cl.getNumAtoms();
    int numSteps = integrator.getNumComputations();
    
    // Loop over computation steps in the integrator and execute them.

    for (int i = 0; i < numSteps; i++) {
        int lastForceGroups = context.getLastForceGroups();
        if ((needsForces[i] || needsEnergy[i]) && (!forcesAreValid || lastForceGroups != forceGroup[i])) {
            if (forcesAreValid && savedForces.find(lastForceGroups) != savedForces.end()) {
                // The forces are still valid.  We just need a different force group right now.  Save the old
                // forces in case we need them again.
                
                cl.getForce().copyTo(*savedForces[lastForceGroups]);
                validSavedForces.insert(lastForceGroups);
            }
            else
                validSavedForces.clear();
            
            // Recompute forces and/or energy.  Figure out what is actually needed
            // between now and the next time they get invalidated again.  The scan starts at
            // this step and wraps around to the beginning, stopping at the first step that
            // invalidates the forces or after every step has been examined once.
            
            bool computeForce = false, computeEnergy = false;
            for (int j = i; ; j++) {
                if (needsForces[j])
                    computeForce = true;
                if (needsEnergy[j])
                    computeEnergy = true;
                if (invalidatesForces[j])
                    break;
                if (j == numSteps-1)
                    j = -1;
                if (j == i-1)
                    break;
            }
            if (!computeEnergy && validSavedForces.find(forceGroup[i]) != validSavedForces.end()) {
                // We can just restore the forces we saved earlier.
                
                savedForces[forceGroup[i]]->copyTo(cl.getForce());
            }
            else {
                recordChangedParameters(context);
                context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroup[i]);
                if (computeEnergy)
                    cl.executeKernel(sumPotentialEnergyKernel, OpenCLContext::ThreadBlockSize, OpenCLContext::ThreadBlockSize);
                forcesAreValid = true;
            }
        }
        if (stepType[i] == CustomIntegrator::ComputePerDof && !merged[i]) {
            // Arguments 9, 10 and 11 were left unset when the kernel was created.
            
            kernels[i][0].setArg<cl_uint>(10, integration.prepareRandomNumbers(requiredGaussian[i]));
            kernels[i][0].setArg<cl::Buffer>(9, integration.getRandom().getDeviceBuffer());
            kernels[i][0].setArg<cl::Buffer>(11, uniformRandoms->getDeviceBuffer());
            if (requiredUniform[i] > 0)
                cl.executeKernel(randomKernel, numAtoms);
            cl.executeKernel(kernels[i][0], numAtoms);
        }
        else if (stepType[i] == CustomIntegrator::ComputeGlobal && !merged[i]) {
            // Arguments 3 and 4 carry one uniform and one normal random value for this step.
            
            kernels[i][0].setArg<cl_float>(3, (cl_float) SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
            kernels[i][0].setArg<cl_float>(4, (cl_float) SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
            cl.executeKernel(kernels[i][0], 1, 1);
        }
        else if (stepType[i] == CustomIntegrator::ComputeSum) {
            kernels[i][0].setArg<cl_uint>(10, integration.prepareRandomNumbers(requiredGaussian[i]));
            kernels[i][0].setArg<cl::Buffer>(9, integration.getRandom().getDeviceBuffer());
            kernels[i][0].setArg<cl::Buffer>(11, uniformRandoms->getDeviceBuffer());
            if (requiredUniform[i] > 0)
                cl.executeKernel(randomKernel, numAtoms);
            
            // The first kernel writes the per-DOF values, the second one sums them.
            
            cl.clearBuffer(*sumBuffer);
            cl.executeKernel(kernels[i][0], numAtoms);
            cl.executeKernel(kernels[i][1], OpenCLContext::ThreadBlockSize, OpenCLContext::ThreadBlockSize);
        }
        else if (stepType[i] == CustomIntegrator::UpdateContextState) {
            recordChangedParameters(context);
            context.updateContextState();
        }
        else if (stepType[i] == CustomIntegrator::ConstrainPositions) {
            integration.applyConstraints(integrator.getConstraintTolerance());
            cl.executeKernel(kernels[i][0], numAtoms);
            integration.computeVirtualSites();
        }
        else if (stepType[i] == CustomIntegrator::ConstrainVelocities) {
            integration.applyVelocityConstraints(integrator.getConstraintTolerance());
        }
        if (invalidatesForces[i])
            forcesAreValid = false;
    }
    recordChangedParameters(context);

    // Update the time and step count.

    cl.setTime(cl.getTime()+integrator.getStepSize());
    cl.setStepCount(cl.getStepCount()+1);
    cl.reorderAtoms();
    if (cl.getAtomsWereReordered()) {
        // Neither the current forces nor any saved copies match the new atom order.
        
        forcesAreValid = false;
        validSavedForces.clear();
    }
    
    // Reduce UI lag.
    
#ifdef WIN32
    cl.getQueue().flush();
#endif
}

6506
6507
6508
6509
6510
6511
6512
6513
6514
6515
6516
6517
6518
6519
/**
 * Evaluate the integrator's kinetic energy on the device and return it.
 *
 * @param context         the context the integrator is attached to
 * @param integrator      the CustomIntegrator whose computations are inspected
 * @param forcesAreValid  in/out flag; set to true here if forces had to be recomputed
 * @return the value downloaded from the kineticEnergy array
 */
double OpenCLIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    prepareForComputation(context, integrator, forcesAreValid);
    if (keNeedsForce && !forcesAreValid) {
        // Forces are required but stale.  Because they are flagged as valid afterwards,
        // the potential energy is evaluated too whenever any computation step relies on it.

        bool energyRequired = false;
        const int numComputations = integrator.getNumComputations();
        for (int step = 0; step < numComputations && !energyRequired; step++)
            energyRequired = needsEnergy[step];
        context.calcForcesAndEnergy(true, energyRequired, -1);
        if (energyRequired)
            cl.executeKernel(sumPotentialEnergyKernel, OpenCLContext::ThreadBlockSize, OpenCLContext::ThreadBlockSize);
        forcesAreValid = true;
    }

    // Accumulate the per-atom contributions into a freshly cleared buffer, then reduce them.

    cl.clearBuffer(*sumBuffer);
    kineticEnergyKernel.setArg<cl::Buffer>(9, cl.getIntegrationUtilities().getRandom().getDeviceBuffer());
    kineticEnergyKernel.setArg<cl_uint>(10, 0);
    cl.executeKernel(kineticEnergyKernel, cl.getNumAtoms());
    cl.executeKernel(sumKineticEnergyKernel, OpenCLContext::ThreadBlockSize, OpenCLContext::ThreadBlockSize);

    // The result is read back as a float in single precision and as a double otherwise.

    const bool highPrecision = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    if (!highPrecision) {
        float singleResult;
        kineticEnergy->download(&singleResult);
        return singleResult;
    }
    double doubleResult;
    kineticEnergy->download(&doubleResult);
    return doubleResult;
}

6537
6538
6539
/**
 * Download the device copy of the context parameters and push every value that
 * differs from the Context's current value back into the Context.
 * Does nothing unless modifiesParameters is set.
 *
 * @param context  the context whose parameters are compared and updated
 */
void OpenCLIntegrateCustomStepKernel::recordChangedParameters(ContextImpl& context) {
    if (!modifiesParameters)
        return;
    const int numParameters = static_cast<int>(parameterNames.size());
    const bool highPrecision = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    if (!highPrecision) {
        // Single precision: compare as floats so rounding alone is not reported as a change.

        contextParameterValues->download(contextValuesFloat);
        for (int p = 0; p < numParameters; p++) {
            const float current = static_cast<float>(context.getParameter(parameterNames[p]));
            if (current != contextValuesFloat[p])
                context.setParameter(parameterNames[p], contextValuesFloat[p]);
        }
        return;
    }
    contextParameterValues->download(contextValuesDouble);
    for (int p = 0; p < numParameters; p++) {
        const double current = context.getParameter(parameterNames[p]);
        if (current != contextValuesDouble[p])
            context.setParameter(parameterNames[p], contextValuesDouble[p]);
    }
}

6558
/**
 * Copy the integrator's global variables from the device into a host vector.
 *
 * @param context  the context the integrator is attached to (not used here)
 * @param values   output vector; emptied when there are no global variables,
 *                 otherwise filled with one entry per global variable
 */
void OpenCLIntegrateCustomStepKernel::getGlobalVariables(ContextImpl& context, vector<double>& values) const {
    if (numGlobalVariables == 0) {
        values.resize(0);
        return;
    }
    if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision())
        globalValues->download(values);
    else {
        // The device holds floats here, so download into a temporary and widen.
        // The output vector must be sized explicitly: the caller may pass an empty
        // or shorter vector, and writing values[i] would otherwise be out of bounds.

        vector<cl_float> buffer;
        globalValues->download(buffer);
        values.resize(numGlobalVariables);
        for (int i = 0; i < numGlobalVariables; i++)
            values[i] = buffer[i];
    }
}

/**
 * Upload new values for the integrator's global variables to the device.
 * Does nothing when the integrator has no global variables.
 *
 * @param context  the context the integrator is attached to (not used here)
 * @param values   the new values; the first numGlobalVariables entries are read
 *                 in the single precision path
 */
void OpenCLIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, const vector<double>& values) {
    if (numGlobalVariables == 0)
        return;
    const bool highPrecision = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    if (highPrecision) {
        globalValues->upload(values);
        return;
    }

    // Single precision: narrow each value to cl_float before uploading.

    vector<cl_float> singleValues(numGlobalVariables);
    for (int g = 0; g < numGlobalVariables; g++)
        singleValues[g] = static_cast<cl_float>(values[g]);
    globalValues->upload(singleValues);
}

/**
 * Retrieve one per-DOF variable for every particle.
 *
 * The host-side cache of per-DOF values is refreshed from the device only when
 * localValuesAreCurrent is false.  Entry i of the cache is written to
 * values[cl.getAtomIndex()[i]].
 *
 * @param context   the context the integrator is attached to (not used here)
 * @param variable  index of the per-DOF variable to read
 * @param values    output vector, resized to perDofValues->getNumObjects()/3
 */
void OpenCLIntegrateCustomStepKernel::getPerDofVariable(ContextImpl& context, int variable, vector<Vec3>& values) const {
    values.resize(perDofValues->getNumObjects()/3);
    const vector<int>& atomOrder = cl.getAtomIndex();
    const int numParticles = static_cast<int>(values.size());
    const bool highPrecision = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    if (highPrecision) {
        if (!localValuesAreCurrent) {
            perDofValues->getParameterValues(localPerDofValuesDouble);
            localValuesAreCurrent = true;
        }
        for (int atom = 0; atom < numParticles; atom++) {
            Vec3& target = values[atomOrder[atom]];
            for (int axis = 0; axis < 3; axis++)
                target[axis] = localPerDofValuesDouble[3*atom+axis][variable];
        }
    }
    else {
        if (!localValuesAreCurrent) {
            perDofValues->getParameterValues(localPerDofValuesFloat);
            localValuesAreCurrent = true;
        }
        for (int atom = 0; atom < numParticles; atom++) {
            Vec3& target = values[atomOrder[atom]];
            for (int axis = 0; axis < 3; axis++)
                target[axis] = localPerDofValuesFloat[3*atom+axis][variable];
        }
    }
}

/**
 * Store new values for one per-DOF variable in the host-side cache.
 *
 * The cache is first refreshed from the device if localValuesAreCurrent is false,
 * so the other variables keep their values.  Entry i of the cache receives
 * values[cl.getAtomIndex()[i]].  Nothing is uploaded here; deviceValuesAreCurrent
 * is cleared instead.
 *
 * @param context   the context the integrator is attached to (not used here)
 * @param variable  index of the per-DOF variable to write
 * @param values    the new value of the variable for each particle
 */
void OpenCLIntegrateCustomStepKernel::setPerDofVariable(ContextImpl& context, int variable, const vector<Vec3>& values) {
    const vector<int>& atomOrder = cl.getAtomIndex();
    const int numParticles = static_cast<int>(values.size());
    const bool highPrecision = cl.getUseDoublePrecision() || cl.getUseMixedPrecision();
    if (highPrecision) {
        if (!localValuesAreCurrent) {
            perDofValues->getParameterValues(localPerDofValuesDouble);
            localValuesAreCurrent = true;
        }
        for (int atom = 0; atom < numParticles; atom++) {
            const Vec3& source = values[atomOrder[atom]];
            for (int axis = 0; axis < 3; axis++)
                localPerDofValuesDouble[3*atom+axis][variable] = source[axis];
        }
    }
    else {
        if (!localValuesAreCurrent) {
            perDofValues->getParameterValues(localPerDofValuesFloat);
            localValuesAreCurrent = true;
        }
        for (int atom = 0; atom < numParticles; atom++) {
            const Vec3& source = values[atomOrder[atom]];
            for (int axis = 0; axis < 3; axis++)
                localPerDofValuesFloat[3*atom+axis][variable] = static_cast<float>(source[axis]);
        }
    }
    deviceValuesAreCurrent = false;
}

6632
/**
 * Release the device array holding the atom group assignments.
 */
OpenCLApplyAndersenThermostatKernel::~OpenCLApplyAndersenThermostatKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is needed.
    delete atomGroups;
}

void OpenCLApplyAndersenThermostatKernel::initialize(const System& system, const AndersenThermostat& thermostat) {
    randomSeed = thermostat.getRandomNumberSeed();
    map<string, string> defines;
6640
    defines["NUM_ATOMS"] = cl.intToString(cl.getNumAtoms());
6641
    cl::Program program = cl.createProgram(OpenCLKernelSources::andersenThermostat, defines);
6642
    kernel = cl::Kernel(program, "applyAndersenThermostat");
Peter Eastman's avatar
Peter Eastman committed
6643
    cl.getIntegrationUtilities().initRandomNumberGenerator(randomSeed);
6644
6645
6646
6647

    // Create the arrays with the group definitions.

    vector<vector<int> > groups = AndersenThermostatImpl::calcParticleGroups(system);
6648
    atomGroups = OpenCLArray::create<int>(cl, cl.getNumAtoms(), "atomGroups");
6649
6650
6651
6652
6653
6654
    vector<int> atoms(atomGroups->getSize());
    for (int i = 0; i < (int) groups.size(); i++) {
        for (int j = 0; j < (int) groups[i].size(); j++)
            atoms[groups[i][j]] = i;
    }
    atomGroups->upload(atoms);
6655
6656
6657
6658
6659
6660
6661
6662
}

/**
 * Run the Andersen thermostat kernel once.
 *
 * The buffer arguments are bound on the first call only; the collision frequency,
 * BOLTZ times the temperature, and the random number index are set on every call.
 *
 * @param context  the context providing the current thermostat parameters
 */
void OpenCLApplyAndersenThermostatKernel::execute(ContextImpl& context) {
    if (!hasInitializedKernels) {
        // One-time binding of the buffers that do not change between calls.

        hasInitializedKernels = true;
        kernel.setArg<cl::Buffer>(2, cl.getVelm().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(3, cl.getIntegrationUtilities().getStepSize().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(4, cl.getIntegrationUtilities().getRandom().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(6, atomGroups->getDeviceBuffer());
    }
    const cl_float collisionFrequency = static_cast<cl_float>(context.getParameter(AndersenThermostat::CollisionFrequency()));
    kernel.setArg<cl_float>(0, collisionFrequency);
    const cl_float kT = static_cast<cl_float>(BOLTZ*context.getParameter(AndersenThermostat::Temperature()));
    kernel.setArg<cl_float>(1, kT);
    const cl_uint randomIndex = cl.getIntegrationUtilities().prepareRandomNumbers(cl.getPaddedNumAtoms());
    kernel.setArg<cl_uint>(5, randomIndex);
    cl.executeKernel(kernel, cl.getNumAtoms());
}

6671
6672
6673
6674
6675
6676
6677
6678
6679
/**
 * Release the device arrays owned by the barostat kernel.
 */
OpenCLApplyMonteCarloBarostatKernel::~OpenCLApplyMonteCarloBarostatKernel() {
    // Deleting a null pointer is a no-op, so arrays that were never created are safe here.
    delete savedPositions;
    delete moleculeAtoms;
    delete moleculeStartIndex;
}

6680
void OpenCLApplyMonteCarloBarostatKernel::initialize(const System& system, const Force& thermostat) {
6681
    savedPositions = new OpenCLArray(cl, cl.getPaddedNumAtoms(), cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4), "savedPositions");
6682
    cl::Program program = cl.createProgram(OpenCLKernelSources::monteCarloBarostat);
6683
    kernel = cl::Kernel(program, "scalePositions");
6684
6685
}

6686
/**
 * Back up the current positions, then run the scalePositions kernel with the
 * given scale factors.
 *
 * On the first call the molecule definitions are flattened into two device arrays:
 * moleculeAtoms lists the atoms of every molecule back to back, and
 * moleculeStartIndex[i] is the offset of molecule i in that list (with one extra
 * trailing entry holding the total count).
 *
 * @param context  the context providing the molecule definitions
 * @param scaleX   scale factor passed to the kernel for the x axis
 * @param scaleY   scale factor passed to the kernel for the y axis
 * @param scaleZ   scale factor passed to the kernel for the z axis
 */
void OpenCLApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, double scaleX, double scaleY, double scaleZ) {
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // Flatten the molecule definitions into an atom list plus start offsets.

        vector<vector<int> > molecules = context.getMolecules();
        numMolecules = molecules.size();
        moleculeAtoms = OpenCLArray::create<int>(cl, cl.getNumAtoms(), "moleculeAtoms");
        moleculeStartIndex = OpenCLArray::create<int>(cl, numMolecules+1, "moleculeStartIndex");
        vector<int> flatAtoms(moleculeAtoms->getSize());
        vector<int> offsets(moleculeStartIndex->getSize());
        int next = 0;
        for (int mol = 0; mol < numMolecules; mol++) {
            const vector<int>& members = molecules[mol];
            const int numMembers = static_cast<int>(members.size());
            offsets[mol] = next;
            for (int m = 0; m < numMembers; m++)
                flatAtoms[next++] = members[m];
        }
        offsets[numMolecules] = next;
        moleculeAtoms->upload(flatAtoms);
        moleculeStartIndex->upload(offsets);

        // Bind the kernel arguments that stay fixed between calls.

        kernel.setArg<cl_int>(3, numMolecules);
        kernel.setArg<cl::Buffer>(9, cl.getPosq().getDeviceBuffer());
        kernel.setArg<cl::Buffer>(10, moleculeAtoms->getDeviceBuffer());
        kernel.setArg<cl::Buffer>(11, moleculeStartIndex->getDeviceBuffer());
    }

    // Save the unscaled positions so restoreCoordinates() can put them back.

    int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
    cl.getQueue().enqueueCopyBuffer(cl.getPosq().getDeviceBuffer(), savedPositions->getDeviceBuffer(), 0, 0, bytesToCopy);

    // Scale the positions on the device.

    kernel.setArg<cl_float>(0, static_cast<cl_float>(scaleX));
    kernel.setArg<cl_float>(1, static_cast<cl_float>(scaleY));
    kernel.setArg<cl_float>(2, static_cast<cl_float>(scaleZ));
    setPeriodicBoxArgs(cl, kernel, 4);
    cl.executeKernel(kernel, cl.getNumAtoms());

    // Zero every periodic cell offset and remember the atom order in effect.

    const int numOffsets = static_cast<int>(cl.getPosCellOffsets().size());
    for (int atom = 0; atom < numOffsets; atom++)
        cl.getPosCellOffsets()[atom] = mm_int4(0, 0, 0, 0);
    lastAtomOrder = cl.getAtomIndex();
}

/**
 * Copy the positions saved by scaleCoordinates() back into the position array.
 *
 * @param context  the context being simulated (not used here)
 */
void OpenCLApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) {
    // The copy is a raw byte copy, so its length depends on the position element type.

    int bytesToCopy = cl.getPosq().getSize()*(cl.getUseDoublePrecision() ? sizeof(mm_double4) : sizeof(mm_float4));
    cl.getQueue().enqueueCopyBuffer(
            savedPositions->getDeviceBuffer(),
            cl.getPosq().getDeviceBuffer(),
            0, 0, bytesToCopy);
}

6732
6733
6734
6735
6736
6737
6738
6739
/**
 * Release the device array holding the per-block momentum sums.
 */
OpenCLRemoveCMMotionKernel::~OpenCLRemoveCMMotionKernel() {
    // Deleting a null pointer is a no-op, so no explicit check is needed.
    delete cmMomentum;
}

void OpenCLRemoveCMMotionKernel::initialize(const System& system, const CMMotionRemover& force) {
    frequency = force.getFrequency();
    int numAtoms = cl.getNumAtoms();
6740
    cmMomentum = OpenCLArray::create<mm_float4>(cl, (numAtoms+OpenCLContext::ThreadBlockSize-1)/OpenCLContext::ThreadBlockSize, "cmMomentum");
6741
6742
6743
6744
    double totalMass = 0.0;
    for (int i = 0; i < numAtoms; i++)
        totalMass += system.getParticleMass(i);
    map<string, string> defines;
6745
    defines["INVERSE_TOTAL_MASS"] = cl.doubleToString(totalMass == 0 ? 0.0 : 1.0/totalMass);
6746
    cl::Program program = cl.createProgram(OpenCLKernelSources::removeCM, defines);
6747
6748
6749
6750
6751
6752
6753
6754
6755
6756
6757
6758
6759
6760
6761
6762
    kernel1 = cl::Kernel(program, "calcCenterOfMassMomentum");
    kernel1.setArg<cl_int>(0, numAtoms);
    kernel1.setArg<cl::Buffer>(1, cl.getVelm().getDeviceBuffer());
    kernel1.setArg<cl::Buffer>(2, cmMomentum->getDeviceBuffer());
    kernel1.setArg(3, OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
    kernel2 = cl::Kernel(program, "removeCenterOfMassMomentum");
    kernel2.setArg<cl_int>(0, numAtoms);
    kernel2.setArg<cl::Buffer>(1, cl.getVelm().getDeviceBuffer());
    kernel2.setArg<cl::Buffer>(2, cmMomentum->getDeviceBuffer());
    kernel2.setArg(3, OpenCLContext::ThreadBlockSize*sizeof(mm_float4), NULL);
}

/**
 * Run the momentum calculation kernel followed by the momentum removal kernel.
 *
 * @param context  the context being simulated (not used here)
 */
void OpenCLRemoveCMMotionKernel::execute(ContextImpl& context) {
    const int numAtoms = cl.getNumAtoms();
    cl.executeKernel(kernel1, numAtoms);
    cl.executeKernel(kernel2, numAtoms);
}