"platforms/opencl/tests/TestOpenCLNonbondedForce.cpp" did not exist on "64be7b6857b0e76ee222c1038ebb190ad6e677f9"
CommonDrudeKernels.cpp 23.8 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2013-2024 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

32
33
#include "CommonDrudeKernels.h"
#include "CommonDrudeKernelSources.h"
34
#include "openmm/internal/ContextImpl.h"
35
36
#include "openmm/common/BondedUtilities.h"
#include "openmm/common/ComputeForceInfo.h"
37
#include "openmm/common/ContextSelector.h"
38
39
#include "openmm/common/IntegrationUtilities.h"
#include "CommonKernelSources.h"
40
#include "SimTKOpenMMRealType.h"
41
42
43
44
45
#include <set>

using namespace OpenMM;
using namespace std;

46
class CommonDrudeForceInfo : public ComputeForceInfo {
47
public:
48
    CommonDrudeForceInfo(const DrudeForce& force) : force(force) {
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
    }
    int getNumParticleGroups() {
        return force.getNumParticles()+force.getNumScreenedPairs();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        particles.clear();
        if (index < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(index, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            if (p2 != -1)
                particles.push_back(p2);
            if (p3 != -1)
                particles.push_back(p3);
            if (p4 != -1)
                particles.push_back(p4);
        }
        else {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(index-force.getNumParticles(), drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
        }
    }
    bool areGroupsIdentical(int group1, int group2) {
        if (group1 < force.getNumParticles() && group2 < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge1, polarizability1, aniso12_1, aniso34_1;
            double charge2, polarizability2, aniso12_2, aniso34_2;
            force.getParticleParameters(group1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12_1, aniso34_1);
            force.getParticleParameters(group2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12_2, aniso34_2);
            return (charge1 == charge2 && polarizability1 == polarizability2 && aniso12_1 == aniso12_2 && aniso34_1 == aniso34_2);
        }
        if (group1 >= force.getNumParticles() && group2 >= force.getNumParticles()) {
            int drude1, drude2;
            double thole1, thole2;
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole1);
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole2);
            return (thole1 == thole2);
        }
        return false;
    }
private:
    const DrudeForce& force;
};

104
105
void CommonCalcDrudeForceKernel::initialize(const System& system, const DrudeForce& force) {
    if (cc.getContextIndex() != 0)
106
        return; // This is run entirely on one device
107
    ContextSelector selector(cc);
108
    int numParticles = force.getNumParticles();
109
110
111
112
    if (numParticles > 0) {
        // Create the harmonic interaction .
        
        vector<vector<int> > atoms(numParticles, vector<int>(5));
113
        particleParams.initialize<mm_float4>(cc, numParticles, "drudeParticleParams");
114
115
116
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            double charge, polarizability, aniso12, aniso34;
117
            force.getParticleParameters(i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], atoms[i][4], charge, polarizability, aniso12, aniso34);
118
119
120
            double a1 = (atoms[i][2] == -1 ? 1 : aniso12);
            double a2 = (atoms[i][3] == -1 || atoms[i][4] == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
121
122
123
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
124
            if (atoms[i][2] == -1) {
125
                atoms[i][2] = atoms[i][0];
126
127
128
                k1 = 0;
            }
            if (atoms[i][3] == -1 || atoms[i][4] == -1) {
129
130
                atoms[i][3] = atoms[i][0];
                atoms[i][4] = atoms[i][0];
131
132
133
134
                k2 = 0;
            }
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
peastman's avatar
peastman committed
135
        particleParams.upload(paramVector);
136
        map<string, string> replacements;
137
138
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(particleParams, "float4");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudeParticleForce, replacements), force.getForceGroup());
139
    }
140
    int numPairs = force.getNumScreenedPairs();
141
142
143
144
    if (numPairs > 0) {
        // Create the screened interaction between dipole pairs.
        
        vector<vector<int> > atoms(numPairs, vector<int>(4));
145
        pairParams.initialize<mm_float2>(cc, numPairs, "drudePairParams");
146
147
148
149
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
150
            force.getScreenedPairParameters(i, drude1, drude2, thole);
151
152
153
154
155
156
157
158
            int p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, atoms[i][0], atoms[i][1], p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, atoms[i][2], atoms[i][3], p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
peastman's avatar
peastman committed
159
        pairParams.upload(paramVector);
160
        map<string, string> replacements;
161
        replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
162
163
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(pairParams, "float2");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudePairForce, replacements), force.getForceGroup());
164
    }
165
    cc.addForce(new CommonDrudeForceInfo(force));
166
167
}

168
/**
 * Does nothing directly.
 *
 * Both Drude interactions were registered with BondedUtilities in initialize().  This
 * kernel therefore has no separate work to launch and contributes no energy of its own.
 */
double CommonCalcDrudeForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    return 0.0;
}

/**
 * Re-uploads the per-interaction parameters after the DrudeForce has been modified.
 *
 * The interaction topology was fixed in initialize(), so the number of Drude particles and
 * screened pairs must not change.  If no parameter array was created (count was zero), the
 * expected count is zero.  A change in either direction, including a drop to zero, throws an
 * OpenMMException.  Atom indices are not re-read; only the parameters are updated.
 */
void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, const DrudeForce& force) {
    if (cc.getContextIndex() != 0)
        return; // This is run entirely on one device
    ContextSelector selector(cc);

    // Set the particle parameters.

    int numParticles = force.getNumParticles();
    int existingParticles = (particleParams.isInitialized() ? (int) particleParams.getSize() : 0);
    if (numParticles != existingParticles)
        throw OpenMMException("updateParametersInContext: The number of Drude particles has changed");
    if (numParticles > 0) {
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);

            // Same spring-constant construction as in initialize(): missing anisotropy atoms
            // give a factor of 1 and a zero correction term.
            double a1 = (p2 == -1 ? 1 : aniso12);
            double a2 = (p3 == -1 || p4 == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
            if (p2 == -1)
                k1 = 0;
            if (p3 == -1 || p4 == -1)
                k2 = 0;
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
        particleParams.upload(paramVector);
    }

    // Set the pair parameters.

    int numPairs = force.getNumScreenedPairs();
    int existingPairs = (pairParams.isInitialized() ? (int) pairParams.getSize() : 0);
    if (numPairs != existingPairs)
        throw OpenMMException("updateParametersInContext: The number of screened pairs has changed");
    if (numPairs > 0) {
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(i, drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
        pairParams.upload(paramVector);
    }
}

void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, const DrudeLangevinIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
228
    ContextSelector selector(cc);
229
    cc.getIntegrationUtilities().initRandomNumberGenerator((unsigned int) integrator.getRandomNumberSeed());
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
    
    // Identify particle pairs and ordinary particles.
    
    set<int> particles;
    vector<int> normalParticleVec;
    vector<mm_int2> pairParticleVec;
    for (int i = 0; i < system.getNumParticles(); i++)
        particles.insert(i);
    for (int i = 0; i < force.getNumParticles(); i++) {
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
        particles.erase(p);
        particles.erase(p1);
        pairParticleVec.push_back(mm_int2(p, p1));
    }
    normalParticleVec.insert(normalParticleVec.begin(), particles.begin(), particles.end());
247
248
    normalParticles.initialize<int>(cc, max((int) normalParticleVec.size(), 1), "drudeNormalParticles");
    pairParticles.initialize<mm_int2>(cc, max((int) pairParticleVec.size(), 1), "drudePairParticles");
249
    if (normalParticleVec.size() > 0)
peastman's avatar
peastman committed
250
        normalParticles.upload(normalParticleVec);
251
    if (pairParticleVec.size() > 0)
peastman's avatar
peastman committed
252
        pairParticles.upload(pairParticleVec);
253
254
255
256

    // Create kernels.
    
    map<string, string> defines;
257
258
259
260
    defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
    defines["NUM_NORMAL_PARTICLES"] = cc.intToString(normalParticleVec.size());
    defines["NUM_PAIRS"] = cc.intToString(pairParticleVec.size());
261
    map<string, string> replacements;
262
263
264
265
    ComputeProgram program = cc.compileProgram(CommonDrudeKernelSources::drudeLangevin, defines);
    kernel1 = program->createKernel("integrateDrudeLangevinPart1");
    kernel2 = program->createKernel("integrateDrudeLangevinPart2");
    hardwallKernel = program->createKernel("applyHardWallConstraints");
266
267
268
    prevStepSize = -1.0;
}

269
/**
 * Advances the system by one Drude Langevin step.
 *
 * Sequence:
 *  1. kernel1 computes velocities and position deltas.
 *  2. Constraints are applied.
 *  3. kernel2 applies the deltas.
 *  4. The hard-wall correction runs, if enabled.
 *  5. Virtual sites are updated.
 *  6. Time and step count advance.
 */
void CommonIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    ContextSelector selector(cc);
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    int numAtoms = cc.getNumAtoms();
    if (!hasInitializedKernels) {
        // Bind kernel arguments once.  Arguments added without a value are placeholders
        // that are filled with setArg() on every step below.
        hasInitializedKernels = true;
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        kernel1->addArg(normalParticles);
        kernel1->addArg(pairParticles);
        kernel1->addArg(integration.getStepSize());
        for (int i = 0; i < 6; i++)
            kernel1->addArg(); // Args 6-11: integrator coefficients
        kernel1->addArg(integration.getRandom());
        kernel1->addArg(); // Arg 13: random number index
        kernel2->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            kernel2->addArg(cc.getPosqCorrection());
        else
            kernel2->addArg(nullptr);
        kernel2->addArg(integration.getPosDelta());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getStepSize());
        hardwallKernel->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            hardwallKernel->addArg(cc.getPosqCorrection());
        else
            hardwallKernel->addArg(nullptr);
        hardwallKernel->addArg(cc.getVelm());
        hardwallKernel->addArg(pairParticles);
        hardwallKernel->addArg(integration.getStepSize());
        hardwallKernel->addArg(); // Arg 5: maximum Drude distance
        hardwallKernel->addArg(); // Arg 6: hard-wall velocity scale
    }

    // Compute integrator coefficients.
    //
    // fscale includes the 1/2^32 factor for the fixed-point long force buffer.  At zero
    // friction the exact limits are vscale = 1, fscale = dt and noisescale = 0.  Handle that
    // case explicitly; otherwise (1-vscale)/friction evaluates 0/0 = NaN.  The noise term
    // sqrt(kT*(1-vscale^2)) equals sqrt(2*kT*friction)*sqrt(0.5*(1-vscale^2)/friction) for
    // friction > 0, but stays finite at friction == 0.

    double stepSize = integrator.getStepSize();
    double friction = integrator.getFriction();
    double drudeFriction = integrator.getDrudeFriction();
    double vscale = exp(-stepSize*friction);
    double fscale = (friction == 0 ? stepSize : (1-vscale)/friction)/(double) 0x100000000;
    double noisescale = sqrt(BOLTZ*integrator.getTemperature()*(1-vscale*vscale));
    double vscaleDrude = exp(-stepSize*drudeFriction);
    double fscaleDrude = (drudeFriction == 0 ? stepSize : (1-vscaleDrude)/drudeFriction)/(double) 0x100000000;
    double noisescaleDrude = sqrt(BOLTZ*integrator.getDrudeTemperature()*(1-vscaleDrude*vscaleDrude));
    double maxDrudeDistance = integrator.getMaxDrudeDistance();
    double hardwallscaleDrude = sqrt(BOLTZ*integrator.getDrudeTemperature());
    if (stepSize != prevStepSize) {
        // The step size is stored in the second component of the device-side value.
        if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
            mm_double2 ss = mm_double2(0, stepSize);
            integration.getStepSize().upload(&ss);
        }
        else {
            mm_float2 ss = mm_float2(0, (float) stepSize);
            integration.getStepSize().upload(&ss);
        }
        prevStepSize = stepSize;
    }
    if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
        kernel1->setArg(6, vscale);
        kernel1->setArg(7, fscale);
        kernel1->setArg(8, noisescale);
        kernel1->setArg(9, vscaleDrude);
        kernel1->setArg(10, fscaleDrude);
        kernel1->setArg(11, noisescaleDrude);
        hardwallKernel->setArg(5, maxDrudeDistance);
        hardwallKernel->setArg(6, hardwallscaleDrude);
    }
    else {
        kernel1->setArg(6, (float) vscale);
        kernel1->setArg(7, (float) fscale);
        kernel1->setArg(8, (float) noisescale);
        kernel1->setArg(9, (float) vscaleDrude);
        kernel1->setArg(10, (float) fscaleDrude);
        kernel1->setArg(11, (float) noisescaleDrude);
        hardwallKernel->setArg(5, (float) maxDrudeDistance);
        hardwallKernel->setArg(6, (float) hardwallscaleDrude);
    }

    // Call the first integration kernel.  One random number is needed per normal particle
    // and two per Drude pair.

    kernel1->setArg(13, integration.prepareRandomNumbers(normalParticles.getSize()+2*pairParticles.getSize()));
    kernel1->execute(numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    kernel2->execute(numAtoms);

    // Apply hard wall constraints (disabled when the maximum Drude distance is 0).

    if (maxDrudeDistance > 0)
        hardwallKernel->execute(pairParticles.getSize());
    integration.computeVirtualSites();

    // Update the time and step count.

    cc.setTime(cc.getTime()+stepSize);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();
}

/**
 * Returns the kinetic energy.
 *
 * Delegates to IntegrationUtilities with a time shift of half the step size.
 */
double CommonIntegrateDrudeLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    return cc.getIntegrationUtilities().computeKineticEnergy(0.5*integrator.getStepSize());
}

/**
 * Destructor.  Owns nothing that needs explicit release; member arrays and kernels clean
 * up after themselves.
 */
CommonIntegrateDrudeSCFStepKernel::~CommonIntegrateDrudeSCFStepKernel() {
}

void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const DrudeSCFIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
383
    ContextSelector selector(cc);
384
385
386
387
388
389
390
391
    int numDrude = force.getNumParticles();
    drudeParams.initialize<mm_float4>(cc, numDrude, "drudeParams");
    drudeIndices.initialize<int>(cc, numDrude, "drudeIndices");
    drudeParents.initialize<mm_int4>(cc, numDrude, "drudeParents");
    vector<mm_float4> paramVec(numDrude);
    vector<mm_int4> parentVec(numDrude);
    drudeIndexVec.resize(numDrude);
    for (int i = 0; i < numDrude; i++) {
392
393
394
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
395
396
397
398
399
400
401
402
403
        double a1 = (p2 == -1 ? 1 : aniso12);
        double a2 = (p3 == -1 || p4 == -1 ? 1 : aniso34);
        double a3 = 3-a1-a2;
        double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
        double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
        double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
        paramVec[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        drudeIndexVec[i] = p;
        parentVec[i] = mm_int4(p1, p2, p3, p4);
404
    }
405
406
407
    drudeParams.upload(paramVec);
    drudeIndices.upload(drudeIndexVec);
    drudeParents.upload(parentVec);
408
409
410

    // Create the kernels.
    
411
412
413
    ComputeProgram program = cc.compileProgram(CommonKernelSources::verlet);
    kernel1 = program->createKernel("integrateVerletPart1");
    kernel2 = program->createKernel("integrateVerletPart2");
414
415
    program = cc.compileProgram(CommonDrudeKernelSources::drudeSCF);
    minimizeKernel = program->createKernel("minimizeDrudePositions");
416
417
418
    prevStepSize = -1.0;
}

419
/**
 * Advances the system by one step.
 *
 * Sequence:
 *  1. Velocity Verlet update of all particles, with constraints applied between the two
 *     kernels.
 *  2. Update virtual sites.
 *  3. Relax the Drude particle positions with minimize().
 */
void CommonIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    ContextSelector selector(cc);
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    int numAtoms = cc.getNumAtoms();
    double dt = integrator.getStepSize();
    if (!hasInitializedKernels) {
        // Bind kernel arguments once.  minimizeKernel arg 2 (the tolerance) is a
        // placeholder set by minimize() on every call.
        hasInitializedKernels = true;
        kernel1->addArg(numAtoms);
        kernel1->addArg(cc.getPaddedNumAtoms());
        kernel1->addArg(integration.getStepSize());
        kernel1->addArg(cc.getPosq());
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        if (cc.getUseMixedPrecision())
            kernel1->addArg(cc.getPosqCorrection());
        kernel2->addArg(numAtoms);
        kernel2->addArg(integration.getStepSize());
        kernel2->addArg(cc.getPosq());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getPosDelta());
        if (cc.getUseMixedPrecision())
            kernel2->addArg(cc.getPosqCorrection());
        minimizeKernel->addArg((int) drudeParams.getSize());
        minimizeKernel->addArg(cc.getPaddedNumAtoms());
        minimizeKernel->addArg();
        minimizeKernel->addArg(cc.getPosq());
        minimizeKernel->addArg(cc.getLongForceBuffer());
        minimizeKernel->addArg(drudeParams);
        minimizeKernel->addArg(drudeIndices);
        minimizeKernel->addArg(drudeParents);
    }
    if (dt != prevStepSize) {
        // Both components of the device-side step size are set to dt.
        if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
            vector<mm_double2> stepSizeVec(1);
            stepSizeVec[0] = mm_double2(dt, dt);
            integration.getStepSize().upload(stepSizeVec);
        }
        else {
            vector<mm_float2> stepSizeVec(1);
            stepSizeVec[0] = mm_float2((float) dt, (float) dt);
            integration.getStepSize().upload(stepSizeVec);
        }
        prevStepSize = dt;
    }

    // Call the first integration kernel.

    kernel1->execute(numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    kernel2->execute(numAtoms);

    // Update the positions of virtual sites and Drude particles.

    integration.computeVirtualSites();
    minimize(context, integrator.getMinimizationErrorTolerance());

    // Update the time and step count.

    cc.setTime(cc.getTime()+dt);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();

    // Reduce UI lag.

#ifdef WIN32
    cc.flushQueue();
#endif
}

/**
 * Returns the kinetic energy.
 *
 * Delegates to IntegrationUtilities with a time shift of half the step size.
 */
double CommonIntegrateDrudeSCFStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    return cc.getIntegrationUtilities().computeKineticEnergy(0.5*integrator.getStepSize());
}

/**
 * Iteratively relaxes the Drude particle positions with the other particles held fixed.
 *
 * Each iteration (at most 50):
 *  1. Recompute forces.
 *  2. Run the minimization kernel.
 *  3. Download the forces and sum their squared magnitudes over the Drude particles.
 *
 * The loop stops when the RMS force component drops below `tolerance`, or when an iteration
 * fails to reduce the summed squared force by at least 10%.
 */
void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double tolerance) {
    int numDrude = drudeParams.getSize();
    if (numDrude == 0)
        return; // Nothing to relax; also avoids 0/0 in the RMS convergence test below.
    minimizeKernel->setArg(2, (float) tolerance);

    // The long force buffer holds 64-bit fixed-point values scaled by 2^32.  Components are
    // stored as all x, then all y, then all z, each block paddedNumAtoms long.
    long long* forces = (long long*) cc.getPinnedBuffer();
    const double scale = 1/(double) 0x100000000;
    const int paddedNumAtoms = cc.getPaddedNumAtoms();
    double lastForce = 0;
    for (int iteration = 0; iteration < 50; iteration++) {
        context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups());
        minimizeKernel->execute(numDrude);
        cc.getLongForceBuffer().download(forces);
        double totalForce = 0; // Sum of squared force magnitudes on the Drude particles
        for (int i : drudeIndexVec) {
            Vec3 f(scale*forces[i], scale*forces[i+paddedNumAtoms], scale*forces[i+paddedNumAtoms*2]);
            totalForce += f.dot(f);
        }
        if (sqrt(totalForce/(3*numDrude)) < tolerance || (iteration > 0 && totalForce > 0.9*lastForce))
            break;
        lastForce = totalForce;
    }
}