/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2013-2020 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

32
33
#include "CommonDrudeKernels.h"
#include "CommonDrudeKernelSources.h"
34
#include "openmm/internal/ContextImpl.h"
35
36
#include "openmm/common/BondedUtilities.h"
#include "openmm/common/ComputeForceInfo.h"
37
#include "openmm/common/ContextSelector.h"
38
39
#include "openmm/common/IntegrationUtilities.h"
#include "CommonKernelSources.h"
40
#include "SimTKOpenMMRealType.h"
41
42
43
44
45
#include <set>

using namespace OpenMM;
using namespace std;

46
class CommonDrudeForceInfo : public ComputeForceInfo {
47
public:
48
    CommonDrudeForceInfo(const DrudeForce& force) : force(force) {
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
    }
    int getNumParticleGroups() {
        return force.getNumParticles()+force.getNumScreenedPairs();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        particles.clear();
        if (index < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(index, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            if (p2 != -1)
                particles.push_back(p2);
            if (p3 != -1)
                particles.push_back(p3);
            if (p4 != -1)
                particles.push_back(p4);
        }
        else {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(index-force.getNumParticles(), drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
        }
    }
    bool areGroupsIdentical(int group1, int group2) {
        if (group1 < force.getNumParticles() && group2 < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge1, polarizability1, aniso12_1, aniso34_1;
            double charge2, polarizability2, aniso12_2, aniso34_2;
            force.getParticleParameters(group1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12_1, aniso34_1);
            force.getParticleParameters(group2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12_2, aniso34_2);
            return (charge1 == charge2 && polarizability1 == polarizability2 && aniso12_1 == aniso12_2 && aniso34_1 == aniso34_2);
        }
        if (group1 >= force.getNumParticles() && group2 >= force.getNumParticles()) {
            int drude1, drude2;
            double thole1, thole2;
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole1);
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole2);
            return (thole1 == thole2);
        }
        return false;
    }
private:
    const DrudeForce& force;
};

104
105
void CommonCalcDrudeForceKernel::initialize(const System& system, const DrudeForce& force) {
    if (cc.getContextIndex() != 0)
106
        return; // This is run entirely on one device
107
    ContextSelector selector(cc);
108
    int numParticles = force.getNumParticles();
109
110
111
112
    if (numParticles > 0) {
        // Create the harmonic interaction .
        
        vector<vector<int> > atoms(numParticles, vector<int>(5));
113
        particleParams.initialize<mm_float4>(cc, numParticles, "drudeParticleParams");
114
115
116
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            double charge, polarizability, aniso12, aniso34;
117
            force.getParticleParameters(i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], atoms[i][4], charge, polarizability, aniso12, aniso34);
118
119
120
            double a1 = (atoms[i][2] == -1 ? 1 : aniso12);
            double a2 = (atoms[i][3] == -1 || atoms[i][4] == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
121
122
123
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
124
            if (atoms[i][2] == -1) {
125
                atoms[i][2] = atoms[i][0];
126
127
128
                k1 = 0;
            }
            if (atoms[i][3] == -1 || atoms[i][4] == -1) {
129
130
                atoms[i][3] = atoms[i][0];
                atoms[i][4] = atoms[i][0];
131
132
133
134
                k2 = 0;
            }
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
peastman's avatar
peastman committed
135
        particleParams.upload(paramVector);
136
        map<string, string> replacements;
137
138
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(particleParams, "float4");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudeParticleForce, replacements), force.getForceGroup());
139
    }
140
    int numPairs = force.getNumScreenedPairs();
141
142
143
144
    if (numPairs > 0) {
        // Create the screened interaction between dipole pairs.
        
        vector<vector<int> > atoms(numPairs, vector<int>(4));
145
        pairParams.initialize<mm_float2>(cc, numPairs, "drudePairParams");
146
147
148
149
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
150
            force.getScreenedPairParameters(i, drude1, drude2, thole);
151
152
153
154
155
156
157
158
            int p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, atoms[i][0], atoms[i][1], p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, atoms[i][2], atoms[i][3], p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
peastman's avatar
peastman committed
159
        pairParams.upload(paramVector);
160
        map<string, string> replacements;
161
162
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(pairParams, "float2");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudePairForce, replacements), force.getForceGroup());
163
    }
164
    cc.addForce(new CommonDrudeForceInfo(force));
165
166
}

167
/**
 * Always reports zero energy and launches nothing itself: initialize()
 * registers both Drude interactions with cc.getBondedUtilities(), so no work
 * is done in this method.
 */
double CommonCalcDrudeForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const double energy = 0.0;
    return energy;
}

171
172
/**
 * Recompute the per-particle and per-pair parameters from `force` and upload
 * them into the existing device arrays.
 *
 * @throws OpenMMException if the number of Drude particles or screened pairs
 *         no longer matches the arrays created by initialize().
 */
void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, const DrudeForce& force) {
    if (cc.getContextIndex() != 0)
        return; // This is run entirely on one device
    ContextSelector selector(cc);

    // Set the particle parameters.

    const int numParticles = force.getNumParticles();
    if (numParticles > 0) {
        if (!particleParams.isInitialized() || numParticles != particleParams.getSize())
            throw OpenMMException("updateParametersInContext: The number of Drude particles has changed");
        vector<mm_float4> springConstants(numParticles);
        for (int i = 0; i < numParticles; i++) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);

            // An index of -1 means the corresponding anisotropy direction is
            // not defined; its constant is then zero.
            const bool lacksAxis12 = (p2 == -1);
            const bool lacksAxis34 = (p3 == -1 || p4 == -1);
            const double a1 = (lacksAxis12 ? 1 : aniso12);
            const double a2 = (lacksAxis34 ? 1 : aniso34);
            const double a3 = 3-a1-a2;
            const double chargeTerm = ONE_4PI_EPS0*charge*charge;
            const double k3 = chargeTerm/(polarizability*a3);
            double k1 = 0;
            double k2 = 0;
            if (!lacksAxis12)
                k1 = chargeTerm/(polarizability*a1) - k3;
            if (!lacksAxis34)
                k2 = chargeTerm/(polarizability*a2) - k3;
            springConstants[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
        particleParams.upload(springConstants);
    }

    // Set the pair parameters.

    const int numPairs = force.getNumScreenedPairs();
    if (numPairs > 0) {
        if (!pairParams.isInitialized() || numPairs != pairParams.getSize())
            throw OpenMMException("updateParametersInContext: The number of screened pairs has changed");
        vector<mm_float2> screeningParams(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(i, drude1, drude2, thole);

            // Only charge and polarizability of the two Drude entries are used.
            int p, p1, p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            const double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            const double energyScale = ONE_4PI_EPS0*charge1*charge2;
            screeningParams[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
        pairParams.upload(screeningParams);
    }
}

225
226
void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, const DrudeLangevinIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
227
    ContextSelector selector(cc);
228
    cc.getIntegrationUtilities().initRandomNumberGenerator((unsigned int) integrator.getRandomNumberSeed());
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
    
    // Identify particle pairs and ordinary particles.
    
    set<int> particles;
    vector<int> normalParticleVec;
    vector<mm_int2> pairParticleVec;
    for (int i = 0; i < system.getNumParticles(); i++)
        particles.insert(i);
    for (int i = 0; i < force.getNumParticles(); i++) {
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
        particles.erase(p);
        particles.erase(p1);
        pairParticleVec.push_back(mm_int2(p, p1));
    }
    normalParticleVec.insert(normalParticleVec.begin(), particles.begin(), particles.end());
246
247
    normalParticles.initialize<int>(cc, max((int) normalParticleVec.size(), 1), "drudeNormalParticles");
    pairParticles.initialize<mm_int2>(cc, max((int) pairParticleVec.size(), 1), "drudePairParticles");
248
    if (normalParticleVec.size() > 0)
peastman's avatar
peastman committed
249
        normalParticles.upload(normalParticleVec);
250
    if (pairParticleVec.size() > 0)
peastman's avatar
peastman committed
251
        pairParticles.upload(pairParticleVec);
252
253
254
255

    // Create kernels.
    
    map<string, string> defines;
256
257
258
259
    defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
    defines["NUM_NORMAL_PARTICLES"] = cc.intToString(normalParticleVec.size());
    defines["NUM_PAIRS"] = cc.intToString(pairParticleVec.size());
260
    map<string, string> replacements;
261
262
263
264
    ComputeProgram program = cc.compileProgram(CommonDrudeKernelSources::drudeLangevin, defines);
    kernel1 = program->createKernel("integrateDrudeLangevinPart1");
    kernel2 = program->createKernel("integrateDrudeLangevinPart2");
    hardwallKernel = program->createKernel("applyHardWallConstraints");
265
266
267
    prevStepSize = -1.0;
}

268
/**
 * Advance the system by one Drude Langevin step: run the first integration
 * kernel, apply constraints, run the second kernel, optionally apply the
 * hard-wall constraint on Drude displacements, then update virtual sites,
 * time and step count.
 *
 * A friction coefficient of exactly zero is handled with the analytic limit of
 * the integrator coefficients instead of dividing by zero.
 */
void CommonIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    ContextSelector selector(cc);
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    int numAtoms = cc.getNumAtoms();
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        // kernel1 arguments; slots 6-11 and 13 are placeholders that are
        // filled in with setArg() on every step below.
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        kernel1->addArg(normalParticles);
        kernel1->addArg(pairParticles);
        kernel1->addArg(integration.getStepSize());
        for (int i = 0; i < 6; i++)
            kernel1->addArg();
        kernel1->addArg(integration.getRandom());
        kernel1->addArg();

        // kernel2 arguments.
        kernel2->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            kernel2->addArg(cc.getPosqCorrection());
        else
            kernel2->addArg(nullptr);
        kernel2->addArg(integration.getPosDelta());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getStepSize());

        // hardwallKernel arguments; slots 5 and 6 are set on every step below.
        hardwallKernel->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            hardwallKernel->addArg(cc.getPosqCorrection());
        else
            hardwallKernel->addArg(nullptr);
        hardwallKernel->addArg(cc.getVelm());
        hardwallKernel->addArg(pairParticles);
        hardwallKernel->addArg(integration.getStepSize());
        hardwallKernel->addArg();
        hardwallKernel->addArg();
    }

    // Compute integrator coefficients.  As friction -> 0, (1-vscale)/friction
    // tends to stepSize and the noise amplitude tends to 0; use those limits
    // directly when friction is exactly zero to avoid computing 0/0.

    double stepSize = integrator.getStepSize();
    double friction = integrator.getFriction();
    double drudeFriction = integrator.getDrudeFriction();
    double vscale = exp(-stepSize*friction);
    double fscale = (friction == 0 ? stepSize : (1-vscale)/friction)/(double) 0x100000000;
    double noisescale = (friction == 0 ? 0.0 : sqrt(2*BOLTZ*integrator.getTemperature()*friction)*sqrt(0.5*(1-vscale*vscale)/friction));
    double vscaleDrude = exp(-stepSize*drudeFriction);
    double fscaleDrude = (drudeFriction == 0 ? stepSize : (1-vscaleDrude)/drudeFriction)/(double) 0x100000000;
    double noisescaleDrude = (drudeFriction == 0 ? 0.0 : sqrt(2*BOLTZ*integrator.getDrudeTemperature()*drudeFriction)*sqrt(0.5*(1-vscaleDrude*vscaleDrude)/drudeFriction));
    double maxDrudeDistance = integrator.getMaxDrudeDistance();
    double hardwallscaleDrude = sqrt(BOLTZ*integrator.getDrudeTemperature());

    // Upload the step size only when it has changed since the previous step.

    bool useDouble = (cc.getUseDoublePrecision() || cc.getUseMixedPrecision());
    if (stepSize != prevStepSize) {
        if (useDouble) {
            mm_double2 ss = mm_double2(0, stepSize);
            integration.getStepSize().upload(&ss);
        }
        else {
            mm_float2 ss = mm_float2(0, (float) stepSize);
            integration.getStepSize().upload(&ss);
        }
        prevStepSize = stepSize;
    }

    // Pass the coefficients in the precision selected above.

    if (useDouble) {
        kernel1->setArg(6, vscale);
        kernel1->setArg(7, fscale);
        kernel1->setArg(8, noisescale);
        kernel1->setArg(9, vscaleDrude);
        kernel1->setArg(10, fscaleDrude);
        kernel1->setArg(11, noisescaleDrude);
        hardwallKernel->setArg(5, maxDrudeDistance);
        hardwallKernel->setArg(6, hardwallscaleDrude);
    }
    else {
        kernel1->setArg(6, (float) vscale);
        kernel1->setArg(7, (float) fscale);
        kernel1->setArg(8, (float) noisescale);
        kernel1->setArg(9, (float) vscaleDrude);
        kernel1->setArg(10, (float) fscaleDrude);
        kernel1->setArg(11, (float) noisescaleDrude);
        hardwallKernel->setArg(5, (float) maxDrudeDistance);
        hardwallKernel->setArg(6, (float) hardwallscaleDrude);
    }

    // Call the first integration kernel.  Random numbers are requested for one
    // per ordinary particle and two per Drude pair.

    kernel1->setArg(13, integration.prepareRandomNumbers(normalParticles.getSize()+2*pairParticles.getSize()));
    kernel1->execute(numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    kernel2->execute(numAtoms);

    // Apply hard wall constraints.

    if (maxDrudeDistance > 0)
        hardwallKernel->execute(pairParticles.getSize());
    integration.computeVirtualSites();

    // Update the time and step count.

    cc.setTime(cc.getTime()+stepSize);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();
}

373
374
/**
 * Delegate the kinetic energy calculation to the IntegrationUtilities, passing
 * half of the integrator's step size.
 * NOTE(review): the argument is presumably a time shift applied to the
 * velocities - confirm against IntegrationUtilities::computeKineticEnergy.
 */
double CommonIntegrateDrudeLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    return integration.computeKineticEnergy(0.5*integrator.getStepSize());
}
376

377
/**
 * Release the L-BFGS position buffer obtained from lbfgs_malloc(), if any.
 */
CommonIntegrateDrudeSCFStepKernel::~CommonIntegrateDrudeSCFStepKernel() {
    if (minimizerPos == nullptr)
        return;
    lbfgs_free(minimizerPos);
}

382
383
void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const DrudeSCFIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
384
    ContextSelector selector(cc);
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404

    // Identify Drude particles.
    
    for (int i = 0; i < force.getNumParticles(); i++) {
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
        drudeParticles.push_back(p);
    }
    
    // Initialize the energy minimizer.
    
    minimizerPos = lbfgs_malloc(drudeParticles.size()*3);
    if (minimizerPos == NULL)
        throw OpenMMException("DrudeSCFIntegrator: Failed to allocate memory");
    lbfgs_parameter_init(&minimizerParams);
    minimizerParams.linesearch = LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE;    

    // Create the kernels.
    
405
406
407
    ComputeProgram program = cc.compileProgram(CommonKernelSources::verlet);
    kernel1 = program->createKernel("integrateVerletPart1");
    kernel2 = program->createKernel("integrateVerletPart2");
408
409
410
    prevStepSize = -1.0;
}

411
/**
 * Advance the system by one step: run the two Verlet kernels with constraints
 * applied between them, update virtual sites, re-minimize the Drude particle
 * positions, and then update the time and step count.
 */
void CommonIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    ContextSelector selector(cc);
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    const int numAtoms = cc.getNumAtoms();
    const double dt = integrator.getStepSize();

    // Bind the kernel arguments the first time through.  The position
    // correction array is only appended in mixed precision mode.

    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
        const bool mixed = cc.getUseMixedPrecision();
        kernel1->addArg(numAtoms);
        kernel1->addArg(cc.getPaddedNumAtoms());
        kernel1->addArg(integration.getStepSize());
        kernel1->addArg(cc.getPosq());
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        if (mixed)
            kernel1->addArg(cc.getPosqCorrection());
        kernel2->addArg(numAtoms);
        kernel2->addArg(integration.getStepSize());
        kernel2->addArg(cc.getPosq());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getPosDelta());
        if (mixed)
            kernel2->addArg(cc.getPosqCorrection());
    }

    // Upload the step size only when it has changed since the previous step.

    if (dt != prevStepSize) {
        if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
            vector<mm_double2> stepSizeVec(1);
            stepSizeVec[0] = mm_double2(dt, dt);
            integration.getStepSize().upload(stepSizeVec);
        }
        else {
            vector<mm_float2> stepSizeVec(1);
            stepSizeVec[0] = mm_float2((float) dt, (float) dt);
            integration.getStepSize().upload(stepSizeVec);
        }
        prevStepSize = dt;
    }

    // Call the first integration kernel.

    kernel1->execute(numAtoms);

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

    kernel2->execute(numAtoms);

    // Update the positions of virtual sites and Drude particles.

    integration.computeVirtualSites();
    minimize(context, integrator.getMinimizationErrorTolerance());

    // Update the time and step count.

    cc.setTime(cc.getTime()+dt);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();

    // Reduce UI lag.

#ifdef WIN32
    cc.flushQueue();
#endif
}

479
480
/**
 * Delegate the kinetic energy calculation to the IntegrationUtilities, passing
 * half of the integrator's step size.
 * NOTE(review): the argument is presumably a time shift applied to the
 * velocities - confirm against IntegrationUtilities::computeKineticEnergy.
 */
double CommonIntegrateDrudeSCFStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    return integration.computeKineticEnergy(0.5*integrator.getStepSize());
}

struct MinimizerData {
    ContextImpl& context;
485
    ComputeContext& cc;
486
    vector<int>& drudeParticles;
487
    MinimizerData(ContextImpl& context, ComputeContext& cc, vector<int>& drudeParticles) : context(context), cc(cc), drudeParticles(drudeParticles) {}
488
489
490
491
492
};

/**
 * L-BFGS objective callback.  Writes the trial coordinates `x` into the
 * positions of the Drude particles, recomputes forces and energy, and stores
 * the gradient with respect to those coordinates in `g`.
 *
 * @param instance  pointer to the MinimizerData for this minimization
 * @param x         trial coordinates, three per Drude particle
 * @param g         output gradient, three per Drude particle
 * @param n         number of variables (unused)
 * @param step      current line search step (unused)
 * @return the potential energy at `x`
 */
static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsfloatval_t *g, const int n, const lbfgsfloatval_t step) {
    MinimizerData* data = static_cast<MinimizerData*>(instance);
    ContextImpl& context = data->context;
    ComputeContext& cc = data->cc;
    const vector<int>& drudeParticles = data->drudeParticles;
    const int numDrudeParticles = drudeParticles.size();

    // Set the particle positions.  The pinned buffer holds the downloaded
    // positions, which are patched on the host and uploaded again.

    void* pinned = cc.getPinnedBuffer();
    cc.getPosq().download(pinned);
    if (cc.getUseDoublePrecision()) {
        mm_double4* posq = (mm_double4*) pinned;
        for (int i = 0; i < numDrudeParticles; ++i) {
            const lbfgsfloatval_t* xyz = x+3*i;
            mm_double4& pos = posq[drudeParticles[i]];
            pos.x = xyz[0];
            pos.y = xyz[1];
            pos.z = xyz[2];
        }
    }
    else {
        mm_float4* posq = (mm_float4*) pinned;
        for (int i = 0; i < numDrudeParticles; ++i) {
            const lbfgsfloatval_t* xyz = x+3*i;
            mm_float4& pos = posq[drudeParticles[i]];
            pos.x = xyz[0];
            pos.y = xyz[1];
            pos.z = xyz[2];
        }
    }
    cc.getPosq().upload(pinned);

    // Compute the forces and energy for this configuration.

    double energy = context.calcForcesAndEnergy(true, true, context.getIntegrator().getIntegrationForceGroups());

    // The force buffer holds 64-bit fixed point values scaled by 2^32, with the
    // x, y and z components stored in consecutive blocks of paddedNumAtoms
    // entries.  The gradient is the negative of the force.

    long long* force = (long long*) cc.getPinnedBuffer();
    cc.getLongForceBuffer().download(force);
    const double forceScale = -1.0/0x100000000;
    const int paddedNumAtoms = cc.getPaddedNumAtoms();
    for (int i = 0; i < numDrudeParticles; ++i) {
        const int atom = drudeParticles[i];
        lbfgsfloatval_t* grad = g+3*i;
        grad[0] = forceScale*force[atom];
        grad[1] = forceScale*force[atom+paddedNumAtoms];
        grad[2] = forceScale*force[atom+paddedNumAtoms*2];
    }
    return energy;
}

536
/**
 * Relax the Drude particle positions with L-BFGS, starting from their current
 * positions and using evaluate() as the objective.
 *
 * @param context    the context used to evaluate forces and energy
 * @param tolerance  error tolerance; it is divided by a normalization constant
 *                   derived from the starting coordinates before being passed
 *                   to the minimizer
 */
void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double tolerance) {
    const int numDrudeParticles = drudeParticles.size();

    // With no Drude particles there is nothing to minimize.  Returning here
    // also avoids the 0/0 in the normalization below, which would store a NaN
    // tolerance in minimizerParams.

    if (numDrudeParticles == 0)
        return;

    // Record the initial positions.

    cc.getPosq().download(cc.getPinnedBuffer());
    if (cc.getUseDoublePrecision()) {
        mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
        for (int i = 0; i < numDrudeParticles; ++i) {
            mm_double4 p = posq[drudeParticles[i]];
            minimizerPos[3*i] = p.x;
            minimizerPos[3*i+1] = p.y;
            minimizerPos[3*i+2] = p.z;
        }
    }
    else {
        mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
        for (int i = 0; i < numDrudeParticles; ++i) {
            mm_float4 p = posq[drudeParticles[i]];
            minimizerPos[3*i] = p.x;
            minimizerPos[3*i+1] = p.y;
            minimizerPos[3*i+2] = p.z;
        }
        // NOTE(review): xtol is overridden only when positions are stored in
        // single precision, presumably to match float accuracy - confirm.
        minimizerParams.xtol = 1e-7;
    }

    // Determine a normalization constant for scaling the tolerance: the RMS
    // coordinate magnitude per particle, but never less than 1.

    double norm = 0.0;
    for (int i = 0; i < 3*numDrudeParticles; i++)
        norm += minimizerPos[i]*minimizerPos[i];
    norm /= numDrudeParticles;
    norm = (norm < 1 ? 1 : sqrt(norm));
    minimizerParams.epsilon = tolerance/norm;

    // Perform the minimization.

    lbfgsfloatval_t fx;
    MinimizerData data(context, cc, drudeParticles);
    lbfgs(numDrudeParticles*3, minimizerPos, &fx, evaluate, NULL, &data, &minimizerParams);
}