CommonDrudeKernels.cpp 23.7 KB
Newer Older
1
2
3
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit.                   *
 * See https://openmm.org/development.                                        *
 *                                                                            *
 * Portions copyright (c) 2013-2024 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

30
31
#include "CommonDrudeKernels.h"
#include "CommonDrudeKernelSources.h"
32
#include "openmm/internal/ContextImpl.h"
33
34
#include "openmm/common/BondedUtilities.h"
#include "openmm/common/ComputeForceInfo.h"
35
#include "openmm/common/ContextSelector.h"
36
37
#include "openmm/common/IntegrationUtilities.h"
#include "CommonKernelSources.h"
38
#include "SimTKOpenMMRealType.h"
39
40
41
42
43
#include <set>

using namespace OpenMM;
using namespace std;

44
class CommonDrudeForceInfo : public ComputeForceInfo {
45
public:
46
    CommonDrudeForceInfo(const DrudeForce& force) : force(force) {
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
    }
    int getNumParticleGroups() {
        return force.getNumParticles()+force.getNumScreenedPairs();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        particles.clear();
        if (index < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(index, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            if (p2 != -1)
                particles.push_back(p2);
            if (p3 != -1)
                particles.push_back(p3);
            if (p4 != -1)
                particles.push_back(p4);
        }
        else {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(index-force.getNumParticles(), drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
        }
    }
    bool areGroupsIdentical(int group1, int group2) {
        if (group1 < force.getNumParticles() && group2 < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge1, polarizability1, aniso12_1, aniso34_1;
            double charge2, polarizability2, aniso12_2, aniso34_2;
            force.getParticleParameters(group1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12_1, aniso34_1);
            force.getParticleParameters(group2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12_2, aniso34_2);
            return (charge1 == charge2 && polarizability1 == polarizability2 && aniso12_1 == aniso12_2 && aniso34_1 == aniso34_2);
        }
        if (group1 >= force.getNumParticles() && group2 >= force.getNumParticles()) {
            int drude1, drude2;
            double thole1, thole2;
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole1);
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole2);
            return (thole1 == thole2);
        }
        return false;
    }
private:
    const DrudeForce& force;
};

102
103
void CommonCalcDrudeForceKernel::initialize(const System& system, const DrudeForce& force) {
    if (cc.getContextIndex() != 0)
104
        return; // This is run entirely on one device
105
    ContextSelector selector(cc);
106
    int numParticles = force.getNumParticles();
107
108
109
110
    if (numParticles > 0) {
        // Create the harmonic interaction .
        
        vector<vector<int> > atoms(numParticles, vector<int>(5));
111
        particleParams.initialize<mm_float4>(cc, numParticles, "drudeParticleParams");
112
113
114
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            double charge, polarizability, aniso12, aniso34;
115
            force.getParticleParameters(i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], atoms[i][4], charge, polarizability, aniso12, aniso34);
116
117
118
            double a1 = (atoms[i][2] == -1 ? 1 : aniso12);
            double a2 = (atoms[i][3] == -1 || atoms[i][4] == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
119
120
121
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
122
            if (atoms[i][2] == -1) {
123
                atoms[i][2] = atoms[i][0];
124
125
126
                k1 = 0;
            }
            if (atoms[i][3] == -1 || atoms[i][4] == -1) {
127
128
                atoms[i][3] = atoms[i][0];
                atoms[i][4] = atoms[i][0];
129
130
131
132
                k2 = 0;
            }
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
peastman's avatar
peastman committed
133
        particleParams.upload(paramVector);
134
        map<string, string> replacements;
135
136
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(particleParams, "float4");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudeParticleForce, replacements), force.getForceGroup());
137
    }
138
    int numPairs = force.getNumScreenedPairs();
139
140
141
142
    if (numPairs > 0) {
        // Create the screened interaction between dipole pairs.
        
        vector<vector<int> > atoms(numPairs, vector<int>(4));
143
        pairParams.initialize<mm_float2>(cc, numPairs, "drudePairParams");
144
145
146
147
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
148
            force.getScreenedPairParameters(i, drude1, drude2, thole);
149
150
151
152
153
154
155
156
            int p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, atoms[i][0], atoms[i][1], p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, atoms[i][2], atoms[i][3], p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
peastman's avatar
peastman committed
157
        pairParams.upload(paramVector);
158
        map<string, string> replacements;
159
        replacements["APPLY_PERIODIC"] = (force.usesPeriodicBoundaryConditions() ? "1" : "0");
160
161
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(pairParams, "float2");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudePairForce, replacements), force.getForceGroup());
162
    }
163
    cc.addForce(new CommonDrudeForceInfo(force));
164
165
}

166
/**
 * The Drude interactions were handed to BondedUtilities in initialize(), so this kernel has no
 * separate work to do and contributes no energy of its own.
 */
double CommonCalcDrudeForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const double ownEnergy = 0.0;
    return ownEnergy;
}

170
171
/**
 * Re-uploads the numeric parameters of a DrudeForce into the arrays built by initialize().
 *
 * Spring constants and screening/energy scales are recomputed from the current charges,
 * polarizabilities, anisotropies, and Thole factors.  Particle indices are not re-read into the
 * bonded interactions, so changing them has no effect here.
 *
 * @throws OpenMMException if the number of Drude particles or screened pairs differs from the
 *         number seen by initialize(), including a change to or from zero.
 */
void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, const DrudeForce& force) {
    if (cc.getContextIndex() != 0)
        return; // This is run entirely on one device
    ContextSelector selector(cc);

    // Set the particle parameters.  The count is compared even when it is now zero.  Otherwise
    // removing every Drude particle would be silently accepted while the old springs stay active.

    int numParticles = force.getNumParticles();
    int oldNumParticles = (particleParams.isInitialized() ? (int) particleParams.getSize() : 0);
    if (numParticles != oldNumParticles)
        throw OpenMMException("updateParametersInContext: The number of Drude particles has changed");
    if (numParticles > 0) {
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            double a1 = (p2 == -1 ? 1 : aniso12);
            double a2 = (p3 == -1 || p4 == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;

            // A missing anisotropy axis contributes no spring (matches initialize()).

            if (p2 == -1)
                k1 = 0;
            if (p3 == -1 || p4 == -1)
                k2 = 0;
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
        particleParams.upload(paramVector);
    }

    // Set the pair parameters, with the same count check as above.

    int numPairs = force.getNumScreenedPairs();
    int oldNumPairs = (pairParams.isInitialized() ? (int) pairParams.getSize() : 0);
    if (numPairs != oldNumPairs)
        throw OpenMMException("updateParametersInContext: The number of screened pairs has changed");
    if (numPairs > 0) {
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(i, drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
        pairParams.upload(paramVector);
    }

    // NOTE(review): CommonDrudeForceInfo::areGroupsIdentical compares these parameters, so an
    // update may change which molecules the context considers identical.  Confirm whether the
    // context's molecule definitions need to be invalidated here.
}

224
225
void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, const DrudeLangevinIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
226
    ContextSelector selector(cc);
227
    cc.getIntegrationUtilities().initRandomNumberGenerator((unsigned int) integrator.getRandomNumberSeed());
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
    
    // Identify particle pairs and ordinary particles.
    
    set<int> particles;
    vector<int> normalParticleVec;
    vector<mm_int2> pairParticleVec;
    for (int i = 0; i < system.getNumParticles(); i++)
        particles.insert(i);
    for (int i = 0; i < force.getNumParticles(); i++) {
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
        particles.erase(p);
        particles.erase(p1);
        pairParticleVec.push_back(mm_int2(p, p1));
    }
    normalParticleVec.insert(normalParticleVec.begin(), particles.begin(), particles.end());
245
246
    normalParticles.initialize<int>(cc, max((int) normalParticleVec.size(), 1), "drudeNormalParticles");
    pairParticles.initialize<mm_int2>(cc, max((int) pairParticleVec.size(), 1), "drudePairParticles");
247
    if (normalParticleVec.size() > 0)
peastman's avatar
peastman committed
248
        normalParticles.upload(normalParticleVec);
249
    if (pairParticleVec.size() > 0)
peastman's avatar
peastman committed
250
        pairParticles.upload(pairParticleVec);
251
252
253
254

    // Create kernels.
    
    map<string, string> defines;
255
256
257
258
    defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
    defines["NUM_NORMAL_PARTICLES"] = cc.intToString(normalParticleVec.size());
    defines["NUM_PAIRS"] = cc.intToString(pairParticleVec.size());
259
    map<string, string> replacements;
260
261
262
263
    ComputeProgram program = cc.compileProgram(CommonDrudeKernelSources::drudeLangevin, defines);
    kernel1 = program->createKernel("integrateDrudeLangevinPart1");
    kernel2 = program->createKernel("integrateDrudeLangevinPart2");
    hardwallKernel = program->createKernel("applyHardWallConstraints");
264
265
266
    prevStepSize = -1.0;
}

267
void CommonIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
268
    ContextSelector selector(cc);
269
270
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    int numAtoms = cc.getNumAtoms();
271
272
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
273
274
275
276
277
278
279
280
281
282
283
284
285
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        kernel1->addArg(normalParticles);
        kernel1->addArg(pairParticles);
        kernel1->addArg(integration.getStepSize());
        for (int i = 0; i < 6; i++)
            kernel1->addArg();
        kernel1->addArg(integration.getRandom());
        kernel1->addArg();
        kernel2->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            kernel2->addArg(cc.getPosqCorrection());
286
        else
287
            kernel2->addArg(nullptr);
288
289
290
291
292
293
        kernel2->addArg(integration.getPosDelta());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getStepSize());
        hardwallKernel->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            hardwallKernel->addArg(cc.getPosqCorrection());
294
        else
295
            hardwallKernel->addArg(nullptr);
296
297
298
299
300
        hardwallKernel->addArg(cc.getVelm());
        hardwallKernel->addArg(pairParticles);
        hardwallKernel->addArg(integration.getStepSize());
        hardwallKernel->addArg();
        hardwallKernel->addArg();
301
302
303
304
305
306
    }
    
    // Compute integrator coefficients.
    
    double stepSize = integrator.getStepSize();
    double vscale = exp(-stepSize*integrator.getFriction());
307
    double fscale = (1-vscale)/integrator.getFriction()/(double) 0x100000000;
308
309
    double noisescale = sqrt(2*BOLTZ*integrator.getTemperature()*integrator.getFriction())*sqrt(0.5*(1-vscale*vscale)/integrator.getFriction());
    double vscaleDrude = exp(-stepSize*integrator.getDrudeFriction());
310
    double fscaleDrude = (1-vscaleDrude)/integrator.getDrudeFriction()/(double) 0x100000000;
311
    double noisescaleDrude = sqrt(2*BOLTZ*integrator.getDrudeTemperature()*integrator.getDrudeFriction())*sqrt(0.5*(1-vscaleDrude*vscaleDrude)/integrator.getDrudeFriction());
312
313
    double maxDrudeDistance = integrator.getMaxDrudeDistance();
    double hardwallscaleDrude = sqrt(BOLTZ*integrator.getDrudeTemperature());
314
    if (stepSize != prevStepSize) {
315
        if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
316
317
318
319
320
321
322
323
324
            mm_double2 ss = mm_double2(0, stepSize);
            integration.getStepSize().upload(&ss);
        }
        else {
            mm_float2 ss = mm_float2(0, (float) stepSize);
            integration.getStepSize().upload(&ss);
        }
        prevStepSize = stepSize;
    }
325
326
327
328
329
330
331
332
333
    if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
            kernel1->setArg(6, vscale);
            kernel1->setArg(7, fscale);
            kernel1->setArg(8, noisescale);
            kernel1->setArg(9, vscaleDrude);
            kernel1->setArg(10, fscaleDrude);
            kernel1->setArg(11, noisescaleDrude);
            hardwallKernel->setArg(5, maxDrudeDistance);
            hardwallKernel->setArg(6, hardwallscaleDrude);
334
335
    }
    else {
336
337
338
339
340
341
342
343
            kernel1->setArg(6, (float) vscale);
            kernel1->setArg(7, (float) fscale);
            kernel1->setArg(8, (float) noisescale);
            kernel1->setArg(9, (float) vscaleDrude);
            kernel1->setArg(10, (float) fscaleDrude);
            kernel1->setArg(11, (float) noisescaleDrude);
            hardwallKernel->setArg(5, (float) maxDrudeDistance);
            hardwallKernel->setArg(6, (float) hardwallscaleDrude);
344
345
346
347
    }

    // Call the first integration kernel.

348
349
    kernel1->setArg(13, integration.prepareRandomNumbers(normalParticles.getSize()+2*pairParticles.getSize()));
    kernel1->execute(numAtoms);
350
351
352
353
354
355
356

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

357
    kernel2->execute(numAtoms);
358
359
360
    
    // Apply hard wall constraints.
    
361
    if (maxDrudeDistance > 0)
362
        hardwallKernel->execute(pairParticles.getSize());
363
364
365
366
    integration.computeVirtualSites();

    // Update the time and step count.

367
368
369
    cc.setTime(cc.getTime()+stepSize);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();
370
371
}

372
373
/**
 * Returns the kinetic energy as computed by IntegrationUtilities.  The time offset passed is half
 * the integrator's step size (presumably the velocity half-step shift; see
 * IntegrationUtilities::computeKineticEnergy).
 */
double CommonIntegrateDrudeLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    const double timeShift = 0.5*integrator.getStepSize();
    IntegrationUtilities& utilities = cc.getIntegrationUtilities();
    return utilities.computeKineticEnergy(timeShift);
}
375

376
/**
 * Destructor.  The kernel owns no resources that need explicit release here.
 */
CommonIntegrateDrudeSCFStepKernel::~CommonIntegrateDrudeSCFStepKernel() {
    // Intentionally empty.
}

379
380
void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const DrudeSCFIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
381
    ContextSelector selector(cc);
382
383
384
385
386
387
388
389
    int numDrude = force.getNumParticles();
    drudeParams.initialize<mm_float4>(cc, numDrude, "drudeParams");
    drudeIndices.initialize<int>(cc, numDrude, "drudeIndices");
    drudeParents.initialize<mm_int4>(cc, numDrude, "drudeParents");
    vector<mm_float4> paramVec(numDrude);
    vector<mm_int4> parentVec(numDrude);
    drudeIndexVec.resize(numDrude);
    for (int i = 0; i < numDrude; i++) {
390
391
392
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
393
394
395
396
397
398
399
400
401
        double a1 = (p2 == -1 ? 1 : aniso12);
        double a2 = (p3 == -1 || p4 == -1 ? 1 : aniso34);
        double a3 = 3-a1-a2;
        double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
        double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
        double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
        paramVec[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        drudeIndexVec[i] = p;
        parentVec[i] = mm_int4(p1, p2, p3, p4);
402
    }
403
404
405
    drudeParams.upload(paramVec);
    drudeIndices.upload(drudeIndexVec);
    drudeParents.upload(parentVec);
406
407
408

    // Create the kernels.
    
409
410
411
    ComputeProgram program = cc.compileProgram(CommonKernelSources::verlet);
    kernel1 = program->createKernel("integrateVerletPart1");
    kernel2 = program->createKernel("integrateVerletPart2");
412
413
    program = cc.compileProgram(CommonDrudeKernelSources::drudeSCF);
    minimizeKernel = program->createKernel("minimizeDrudePositions");
414
415
416
    prevStepSize = -1.0;
}

417
/**
 * Advances the system by one SCF Drude step.
 *
 * Sequence:
 *  - Verlet part 1;
 *  - constraints;
 *  - Verlet part 2;
 *  - virtual sites;
 *  - relax the Drude particle positions with minimize();
 *  - time and step counter update.
 *
 * Kernel arguments are bound on the first call.  The step size is re-uploaded only when it changes.
 */
void CommonIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    ContextSelector selector(cc);
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    const int numAtoms = cc.getNumAtoms();
    const double dt = integrator.getStepSize();
    const bool mixed = cc.getUseMixedPrecision();
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // Verlet part 1.

        kernel1->addArg(numAtoms);
        kernel1->addArg(cc.getPaddedNumAtoms());
        kernel1->addArg(integration.getStepSize());
        kernel1->addArg(cc.getPosq());
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        if (mixed)
            kernel1->addArg(cc.getPosqCorrection());

        // Verlet part 2.

        kernel2->addArg(numAtoms);
        kernel2->addArg(integration.getStepSize());
        kernel2->addArg(cc.getPosq());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getPosDelta());
        if (mixed)
            kernel2->addArg(cc.getPosqCorrection());

        // Drude minimizer.  Argument 2 (the tolerance) is set by minimize().

        minimizeKernel->addArg((int) drudeParams.getSize());
        minimizeKernel->addArg(cc.getPaddedNumAtoms());
        minimizeKernel->addArg();
        minimizeKernel->addArg(cc.getPosq());
        minimizeKernel->addArg(cc.getLongForceBuffer());
        minimizeKernel->addArg(drudeParams);
        minimizeKernel->addArg(drudeIndices);
        minimizeKernel->addArg(drudeParents);
    }
    if (dt != prevStepSize) {
        // Both components of the step size array hold dt.

        if (cc.getUseDoublePrecision() || mixed) {
            vector<mm_double2> stepSizeVec(1, mm_double2(dt, dt));
            integration.getStepSize().upload(stepSizeVec);
        }
        else {
            vector<mm_float2> stepSizeVec(1, mm_float2((float) dt, (float) dt));
            integration.getStepSize().upload(stepSizeVec);
        }
        prevStepSize = dt;
    }

    // Integrate: part 1, constraints, part 2.

    kernel1->execute(numAtoms);
    integration.applyConstraints(integrator.getConstraintTolerance());
    kernel2->execute(numAtoms);

    // Update the positions of virtual sites and Drude particles.

    integration.computeVirtualSites();
    minimize(context, integrator.getMinimizationErrorTolerance());

    // Update the time and step count.

    cc.setTime(cc.getTime()+dt);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();

    // Reduce UI lag.

#ifdef WIN32
    cc.flushQueue();
#endif
}

493
494
/**
 * Returns the kinetic energy as computed by IntegrationUtilities.  The time offset passed is half
 * the integrator's step size (presumably the velocity half-step shift; see
 * IntegrationUtilities::computeKineticEnergy).
 */
double CommonIntegrateDrudeSCFStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    IntegrationUtilities& utilities = cc.getIntegrationUtilities();
    return utilities.computeKineticEnergy(halfStep);
}

497
/**
 * Iteratively relaxes the Drude particle positions.
 *
 * Each iteration does the following:
 *  - recompute forces for the integrator's force groups;
 *  - run the minimizeDrudePositions kernel;
 *  - download the fixed-point force buffer and sum the squared forces on the Drude particles.
 *
 * The loop stops when any of these holds:
 *  - the RMS force per component drops below tolerance;
 *  - the summed squared force fails to fall below 90% of the previous iteration's value;
 *  - maxIterations is reached.
 *
 * @param tolerance  RMS force threshold; also passed to the kernel (as float).
 */
void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double tolerance) {
    const int maxIterations = 50;
    int numDrude = drudeParams.getSize();

    // With no Drude particles there is nothing to relax.  The RMS test below would also compute
    // 0/0 = NaN, which never converges and would waste maxIterations force evaluations.

    if (numDrude == 0)
        return;
    minimizeKernel->setArg(2, (float) tolerance);
    long long* forces = (long long*) cc.getPinnedBuffer();
    double scale = 1/(double) 0x100000000; // Converts the fixed-point long force buffer to doubles.
    double lastForce = 0;
    int paddedNumAtoms = cc.getPaddedNumAtoms();
    for (int iteration = 0; iteration < maxIterations; iteration++) {
        {
            // The force calculation runs with this context deselected.

            ContextDeselector deselector(cc);
            context.calcForcesAndEnergy(true, false, context.getIntegrator().getIntegrationForceGroups());
        }
        minimizeKernel->execute(numDrude);
        cc.getLongForceBuffer().download(forces);

        // The force buffer stores x, y, and z in consecutive blocks of paddedNumAtoms entries.

        double totalForce = 0;
        for (int i : drudeIndexVec) {
            Vec3 f(scale*forces[i], scale*forces[i+paddedNumAtoms], scale*forces[i+paddedNumAtoms*2]);
            totalForce += f.dot(f);
        }
        if (sqrt(totalForce/(3*numDrude)) < tolerance || (iteration > 0 && totalForce > 0.9*lastForce))
            break;
        lastForce = totalForce;
    }
}