CommonDrudeKernels.cpp 24.9 KB
Newer Older
1
2
3
4
5
6
7
8
/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2013-2019 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * Permission is hereby granted, free of charge, to any person obtaining a    *
 * copy of this software and associated documentation files (the "Software"), *
 * to deal in the Software without restriction, including without limitation  *
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,   *
 * and/or sell copies of the Software, and to permit persons to whom the      *
 * Software is furnished to do so, subject to the following conditions:       *
 *                                                                            *
 * The above copyright notice and this permission notice shall be included in *
 * all copies or substantial portions of the Software.                        *
 *                                                                            *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,   *
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL    *
 * THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,    *
 * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR      *
 * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE  *
 * USE OR OTHER DEALINGS IN THE SOFTWARE.                                     *
 * -------------------------------------------------------------------------- */

32
33
#include "CommonDrudeKernels.h"
#include "CommonDrudeKernelSources.h"
34
#include "openmm/internal/ContextImpl.h"
35
36
37
38
#include "openmm/common/BondedUtilities.h"
#include "openmm/common/ComputeForceInfo.h"
#include "openmm/common/IntegrationUtilities.h"
#include "CommonKernelSources.h"
39
#include "SimTKOpenMMRealType.h"
40
41
42
43
44
#include <set>

using namespace OpenMM;
using namespace std;

45
class CommonDrudeForceInfo : public ComputeForceInfo {
46
public:
47
    CommonDrudeForceInfo(const DrudeForce& force) : force(force) {
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
    }
    int getNumParticleGroups() {
        return force.getNumParticles()+force.getNumScreenedPairs();
    }
    void getParticlesInGroup(int index, vector<int>& particles) {
        particles.clear();
        if (index < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(index, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            if (p2 != -1)
                particles.push_back(p2);
            if (p3 != -1)
                particles.push_back(p3);
            if (p4 != -1)
                particles.push_back(p4);
        }
        else {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(index-force.getNumParticles(), drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            particles.push_back(p);
            particles.push_back(p1);
        }
    }
    bool areGroupsIdentical(int group1, int group2) {
        if (group1 < force.getNumParticles() && group2 < force.getNumParticles()) {
            int p, p1, p2, p3, p4;
            double charge1, polarizability1, aniso12_1, aniso34_1;
            double charge2, polarizability2, aniso12_2, aniso34_2;
            force.getParticleParameters(group1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12_1, aniso34_1);
            force.getParticleParameters(group2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12_2, aniso34_2);
            return (charge1 == charge2 && polarizability1 == polarizability2 && aniso12_1 == aniso12_2 && aniso34_1 == aniso34_2);
        }
        if (group1 >= force.getNumParticles() && group2 >= force.getNumParticles()) {
            int drude1, drude2;
            double thole1, thole2;
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole1);
            force.getScreenedPairParameters(group1-force.getNumParticles(), drude1, drude2, thole2);
            return (thole1 == thole2);
        }
        return false;
    }
private:
    const DrudeForce& force;
};

103
104
105
void CommonCalcDrudeForceKernel::initialize(const System& system, const DrudeForce& force) {
    cc.setAsCurrent();
    if (cc.getContextIndex() != 0)
106
107
        return; // This is run entirely on one device
    int numParticles = force.getNumParticles();
108
109
110
111
    if (numParticles > 0) {
        // Create the harmonic interaction .
        
        vector<vector<int> > atoms(numParticles, vector<int>(5));
112
        particleParams.initialize<mm_float4>(cc, numParticles, "drudeParticleParams");
113
114
115
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            double charge, polarizability, aniso12, aniso34;
116
            force.getParticleParameters(i, atoms[i][0], atoms[i][1], atoms[i][2], atoms[i][3], atoms[i][4], charge, polarizability, aniso12, aniso34);
117
118
119
            double a1 = (atoms[i][2] == -1 ? 1 : aniso12);
            double a2 = (atoms[i][3] == -1 || atoms[i][4] == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
120
121
122
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
123
            if (atoms[i][2] == -1) {
124
                atoms[i][2] = atoms[i][0];
125
126
127
                k1 = 0;
            }
            if (atoms[i][3] == -1 || atoms[i][4] == -1) {
128
129
                atoms[i][3] = atoms[i][0];
                atoms[i][4] = atoms[i][0];
130
131
132
133
                k2 = 0;
            }
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
peastman's avatar
peastman committed
134
        particleParams.upload(paramVector);
135
        map<string, string> replacements;
136
137
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(particleParams, "float4");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudeParticleForce, replacements), force.getForceGroup());
138
    }
139
    int numPairs = force.getNumScreenedPairs();
140
141
142
143
    if (numPairs > 0) {
        // Create the screened interaction between dipole pairs.
        
        vector<vector<int> > atoms(numPairs, vector<int>(4));
144
        pairParams.initialize<mm_float2>(cc, numPairs, "drudePairParams");
145
146
147
148
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
149
            force.getScreenedPairParameters(i, drude1, drude2, thole);
150
151
152
153
154
155
156
157
            int p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, atoms[i][0], atoms[i][1], p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, atoms[i][2], atoms[i][3], p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
peastman's avatar
peastman committed
158
        pairParams.upload(paramVector);
159
        map<string, string> replacements;
160
161
        replacements["PARAMS"] = cc.getBondedUtilities().addArgument(pairParams, "float2");
        cc.getBondedUtilities().addInteraction(atoms, cc.replaceStrings(CommonDrudeKernelSources::drudePairForce, replacements), force.getForceGroup());
162
    }
163
    cc.addForce(new CommonDrudeForceInfo(force));
164
165
}

166
/**
 * Nothing is computed here directly.  initialize() registered every Drude
 * interaction with the bonded utilities, which presumably evaluate them during
 * the normal force computation, so this kernel contributes no extra energy.
 */
double CommonCalcDrudeForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
    const double additionalEnergy = 0.0;
    return additionalEnergy;
}

170
171
/**
 * Recompute the per-particle spring constants and per-pair screening parameters
 * from the DrudeForce and upload them to the arrays created in initialize().
 * Only context 0 holds these arrays; other contexts return immediately.
 *
 * Throws OpenMMException if the number of Drude particles or screened pairs
 * differs from the number present when initialize() ran (including a change to
 * or from zero).
 */
void CommonCalcDrudeForceKernel::copyParametersToContext(ContextImpl& context, const DrudeForce& force) {
    // Make this device current before touching device arrays, as initialize() does.
    cc.setAsCurrent();
    if (cc.getContextIndex() != 0)
        return; // This is run entirely on one device

    // Set the particle parameters.  particleParams only exists if initialize()
    // saw at least one Drude particle, so an uninitialized array means "zero".

    int numParticles = force.getNumParticles();
    int initialNumParticles = (particleParams.isInitialized() ? (int) particleParams.getSize() : 0);
    if (numParticles != initialNumParticles)
        throw OpenMMException("updateParametersInContext: The number of Drude particles has changed");
    if (numParticles > 0) {
        vector<mm_float4> paramVector(numParticles);
        for (int i = 0; i < numParticles; i++) {
            int p, p1, p2, p3, p4;
            double charge, polarizability, aniso12, aniso34;
            force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
            double a1 = (p2 == -1 ? 1 : aniso12);
            double a2 = (p3 == -1 || p4 == -1 ? 1 : aniso34);
            double a3 = 3-a1-a2;
            double k3 = ONE_4PI_EPS0*charge*charge/(polarizability*a3);
            double k1 = ONE_4PI_EPS0*charge*charge/(polarizability*a1) - k3;
            double k2 = ONE_4PI_EPS0*charge*charge/(polarizability*a2) - k3;
            // Missing anisotropy particles disable the corresponding spring term.
            if (p2 == -1)
                k1 = 0;
            if (p3 == -1 || p4 == -1)
                k2 = 0;
            paramVector[i] = mm_float4((float) k1, (float) k2, (float) k3, 0.0f);
        }
        particleParams.upload(paramVector);
    }

    // Set the pair parameters, with the same size rule as above.

    int numPairs = force.getNumScreenedPairs();
    int initialNumPairs = (pairParams.isInitialized() ? (int) pairParams.getSize() : 0);
    if (numPairs != initialNumPairs)
        throw OpenMMException("updateParametersInContext: The number of screened pairs has changed");
    if (numPairs > 0) {
        vector<mm_float2> paramVector(numPairs);
        for (int i = 0; i < numPairs; i++) {
            int drude1, drude2;
            double thole;
            force.getScreenedPairParameters(i, drude1, drude2, thole);
            int p, p1, p2, p3, p4;
            double charge1, charge2, polarizability1, polarizability2, aniso12, aniso34;
            force.getParticleParameters(drude1, p, p1, p2, p3, p4, charge1, polarizability1, aniso12, aniso34);
            force.getParticleParameters(drude2, p, p1, p2, p3, p4, charge2, polarizability2, aniso12, aniso34);
            double screeningScale = thole/pow(polarizability1*polarizability2, 1.0/6.0);
            double energyScale = ONE_4PI_EPS0*charge1*charge2;
            paramVector[i] = mm_float2((float) screeningScale, (float) energyScale);
        }
        pairParams.upload(paramVector);
    }
}

223
224
225
void CommonIntegrateDrudeLangevinStepKernel::initialize(const System& system, const DrudeLangevinIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
    cc.getIntegrationUtilities().initRandomNumberGenerator((unsigned int) integrator.getRandomNumberSeed());
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
    
    // Identify particle pairs and ordinary particles.
    
    set<int> particles;
    vector<int> normalParticleVec;
    vector<mm_int2> pairParticleVec;
    for (int i = 0; i < system.getNumParticles(); i++)
        particles.insert(i);
    for (int i = 0; i < force.getNumParticles(); i++) {
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
        particles.erase(p);
        particles.erase(p1);
        pairParticleVec.push_back(mm_int2(p, p1));
    }
    normalParticleVec.insert(normalParticleVec.begin(), particles.begin(), particles.end());
243
244
    normalParticles.initialize<int>(cc, max((int) normalParticleVec.size(), 1), "drudeNormalParticles");
    pairParticles.initialize<mm_int2>(cc, max((int) pairParticleVec.size(), 1), "drudePairParticles");
245
    if (normalParticleVec.size() > 0)
peastman's avatar
peastman committed
246
        normalParticles.upload(normalParticleVec);
247
    if (pairParticleVec.size() > 0)
peastman's avatar
peastman committed
248
        pairParticles.upload(pairParticleVec);
249
250
251
252

    // Create kernels.
    
    map<string, string> defines;
253
254
255
256
    defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms());
    defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
    defines["NUM_NORMAL_PARTICLES"] = cc.intToString(normalParticleVec.size());
    defines["NUM_PAIRS"] = cc.intToString(pairParticleVec.size());
257
    map<string, string> replacements;
258
259
260
261
    ComputeProgram program = cc.compileProgram(CommonDrudeKernelSources::drudeLangevin, defines);
    kernel1 = program->createKernel("integrateDrudeLangevinPart1");
    kernel2 = program->createKernel("integrateDrudeLangevinPart2");
    hardwallKernel = program->createKernel("applyHardWallConstraints");
262
263
264
    prevStepSize = -1.0;
}

265
266
267
268
void CommonIntegrateDrudeLangevinStepKernel::execute(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    cc.setAsCurrent();
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    int numAtoms = cc.getNumAtoms();
269
270
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;
271
272
273
274
275
276
277
278
279
280
281
282
283
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        kernel1->addArg(normalParticles);
        kernel1->addArg(pairParticles);
        kernel1->addArg(integration.getStepSize());
        for (int i = 0; i < 6; i++)
            kernel1->addArg();
        kernel1->addArg(integration.getRandom());
        kernel1->addArg();
        kernel2->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            kernel2->addArg(cc.getPosqCorrection());
284
        else
285
286
287
288
289
290
291
            kernel2->addArg(NULL);
        kernel2->addArg(integration.getPosDelta());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getStepSize());
        hardwallKernel->addArg(cc.getPosq());
        if (cc.getUseMixedPrecision())
            hardwallKernel->addArg(cc.getPosqCorrection());
292
        else
293
294
295
296
297
298
            hardwallKernel->addArg(NULL);
        hardwallKernel->addArg(cc.getVelm());
        hardwallKernel->addArg(pairParticles);
        hardwallKernel->addArg(integration.getStepSize());
        hardwallKernel->addArg();
        hardwallKernel->addArg();
299
300
301
302
303
304
    }
    
    // Compute integrator coefficients.
    
    double stepSize = integrator.getStepSize();
    double vscale = exp(-stepSize*integrator.getFriction());
305
    double fscale = (1-vscale)/integrator.getFriction()/(double) 0x100000000;
306
307
    double noisescale = sqrt(2*BOLTZ*integrator.getTemperature()*integrator.getFriction())*sqrt(0.5*(1-vscale*vscale)/integrator.getFriction());
    double vscaleDrude = exp(-stepSize*integrator.getDrudeFriction());
308
    double fscaleDrude = (1-vscaleDrude)/integrator.getDrudeFriction()/(double) 0x100000000;
309
    double noisescaleDrude = sqrt(2*BOLTZ*integrator.getDrudeTemperature()*integrator.getDrudeFriction())*sqrt(0.5*(1-vscaleDrude*vscaleDrude)/integrator.getDrudeFriction());
310
311
    double maxDrudeDistance = integrator.getMaxDrudeDistance();
    double hardwallscaleDrude = sqrt(BOLTZ*integrator.getDrudeTemperature());
312
    if (stepSize != prevStepSize) {
313
        if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
314
315
316
317
318
319
320
321
322
            mm_double2 ss = mm_double2(0, stepSize);
            integration.getStepSize().upload(&ss);
        }
        else {
            mm_float2 ss = mm_float2(0, (float) stepSize);
            integration.getStepSize().upload(&ss);
        }
        prevStepSize = stepSize;
    }
323
324
325
326
327
328
329
330
331
    if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
            kernel1->setArg(6, vscale);
            kernel1->setArg(7, fscale);
            kernel1->setArg(8, noisescale);
            kernel1->setArg(9, vscaleDrude);
            kernel1->setArg(10, fscaleDrude);
            kernel1->setArg(11, noisescaleDrude);
            hardwallKernel->setArg(5, maxDrudeDistance);
            hardwallKernel->setArg(6, hardwallscaleDrude);
332
333
    }
    else {
334
335
336
337
338
339
340
341
            kernel1->setArg(6, (float) vscale);
            kernel1->setArg(7, (float) fscale);
            kernel1->setArg(8, (float) noisescale);
            kernel1->setArg(9, (float) vscaleDrude);
            kernel1->setArg(10, (float) fscaleDrude);
            kernel1->setArg(11, (float) noisescaleDrude);
            hardwallKernel->setArg(5, (float) maxDrudeDistance);
            hardwallKernel->setArg(6, (float) hardwallscaleDrude);
342
343
344
345
    }

    // Call the first integration kernel.

346
347
    kernel1->setArg(13, integration.prepareRandomNumbers(normalParticles.getSize()+2*pairParticles.getSize()));
    kernel1->execute(numAtoms);
348
349
350
351
352
353
354

    // Apply constraints.

    integration.applyConstraints(integrator.getConstraintTolerance());

    // Call the second integration kernel.

355
    kernel2->execute(numAtoms);
356
357
358
    
    // Apply hard wall constraints.
    
359
    if (maxDrudeDistance > 0)
360
        hardwallKernel->execute(pairParticles.getSize());
361
362
363
364
    integration.computeVirtualSites();

    // Update the time and step count.

365
366
367
    cc.setTime(cc.getTime()+stepSize);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();
368
369
}

370
371
/**
 * Delegate the kinetic energy to IntegrationUtilities, passing a time shift of
 * half the integrator's step size.
 */
double CommonIntegrateDrudeLangevinStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeLangevinIntegrator& integrator) {
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    const double timeShift = 0.5*integrator.getStepSize();
    return integration.computeKineticEnergy(timeShift);
}
373

374
/**
 * Release the L-BFGS coordinate buffer obtained from lbfgs_malloc(), if one was
 * allocated.
 */
CommonIntegrateDrudeSCFStepKernel::~CommonIntegrateDrudeSCFStepKernel() {
    if (minimizerPos == NULL)
        return;
    lbfgs_free(minimizerPos);
}

379
380
381
void CommonIntegrateDrudeSCFStepKernel::initialize(const System& system, const DrudeSCFIntegrator& integrator, const DrudeForce& force) {
    cc.initializeContexts();
    cc.setAsCurrent();
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401

    // Identify Drude particles.
    
    for (int i = 0; i < force.getNumParticles(); i++) {
        int p, p1, p2, p3, p4;
        double charge, polarizability, aniso12, aniso34;
        force.getParticleParameters(i, p, p1, p2, p3, p4, charge, polarizability, aniso12, aniso34);
        drudeParticles.push_back(p);
    }
    
    // Initialize the energy minimizer.
    
    minimizerPos = lbfgs_malloc(drudeParticles.size()*3);
    if (minimizerPos == NULL)
        throw OpenMMException("DrudeSCFIntegrator: Failed to allocate memory");
    lbfgs_parameter_init(&minimizerParams);
    minimizerParams.linesearch = LBFGS_LINESEARCH_BACKTRACKING_STRONG_WOLFE;    

    // Create the kernels.
    
402
403
404
    ComputeProgram program = cc.compileProgram(CommonKernelSources::verlet);
    kernel1 = program->createKernel("integrateVerletPart1");
    kernel2 = program->createKernel("integrateVerletPart2");
405
406
407
    prevStepSize = -1.0;
}

408
409
410
411
/**
 * Take one SCF step: a Verlet step for all atoms (with constraints), then
 * relocate the Drude particles by minimizing the energy with respect to their
 * positions, and finally advance the time and step counters.
 */
void CommonIntegrateDrudeSCFStepKernel::execute(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    cc.setAsCurrent();
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    const int atomCount = cc.getNumAtoms();
    const double dt = integrator.getStepSize();
    const bool mixed = cc.getUseMixedPrecision();

    // Bind kernel arguments on the first call only.  In mixed precision both
    // kernels additionally take the posq correction array.
    if (!hasInitializedKernels) {
        hasInitializedKernels = true;

        kernel1->addArg(atomCount);
        kernel1->addArg(cc.getPaddedNumAtoms());
        kernel1->addArg(integration.getStepSize());
        kernel1->addArg(cc.getPosq());
        kernel1->addArg(cc.getVelm());
        kernel1->addArg(cc.getLongForceBuffer());
        kernel1->addArg(integration.getPosDelta());
        if (mixed)
            kernel1->addArg(cc.getPosqCorrection());

        kernel2->addArg(atomCount);
        kernel2->addArg(integration.getStepSize());
        kernel2->addArg(cc.getPosq());
        kernel2->addArg(cc.getVelm());
        kernel2->addArg(integration.getPosDelta());
        if (mixed)
            kernel2->addArg(cc.getPosqCorrection());
    }

    // Refresh the step size array only when dt changes; both components hold dt.
    if (dt != prevStepSize) {
        if (cc.getUseDoublePrecision() || mixed) {
            vector<mm_double2> stepSizeVec(1, mm_double2(dt, dt));
            integration.getStepSize().upload(stepSizeVec);
        }
        else {
            vector<mm_float2> stepSizeVec(1, mm_float2((float) dt, (float) dt));
            integration.getStepSize().upload(stepSizeVec);
        }
        prevStepSize = dt;
    }

    // Verlet step: first half, constraints, second half.
    kernel1->execute(atomCount);
    integration.applyConstraints(integrator.getConstraintTolerance());
    kernel2->execute(atomCount);

    // Update the positions of virtual sites and Drude particles.
    integration.computeVirtualSites();
    minimize(context, integrator.getMinimizationErrorTolerance());

    // Update the time and step count.
    cc.setTime(cc.getTime()+dt);
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();

    // Reduce UI lag.
#ifdef WIN32
    cc.flushQueue();
#endif
}

476
477
/**
 * Delegate the kinetic energy to IntegrationUtilities, passing a time shift of
 * half the integrator's step size.
 */
double CommonIntegrateDrudeSCFStepKernel::computeKineticEnergy(ContextImpl& context, const DrudeSCFIntegrator& integrator) {
    const double halfStep = 0.5*integrator.getStepSize();
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    return integration.computeKineticEnergy(halfStep);
}

struct MinimizerData {
    ContextImpl& context;
482
    ComputeContext& cc;
483
    vector<int>& drudeParticles;
484
    MinimizerData(ContextImpl& context, ComputeContext& cc, vector<int>& drudeParticles) : context(context), cc(cc), drudeParticles(drudeParticles) {}
485
486
487
488
489
};

static lbfgsfloatval_t evaluate(void *instance, const lbfgsfloatval_t *x, lbfgsfloatval_t *g, const int n, const lbfgsfloatval_t step) {
    MinimizerData* data = reinterpret_cast<MinimizerData*>(instance);
    ContextImpl& context = data->context;
490
    ComputeContext& cc = data->cc;
491
492
    vector<int>& drudeParticles = data->drudeParticles;
    int numDrudeParticles = drudeParticles.size();
493
494
495

    // Set the particle positions.
    
496
497
498
    cc.getPosq().download(cc.getPinnedBuffer());
    if (cc.getUseDoublePrecision()) {
        mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
499
        for (int i = 0; i < numDrudeParticles; ++i) {
500
501
502
503
            mm_double4& p = posq[drudeParticles[i]];
            p.x = x[3*i];
            p.y = x[3*i+1];
            p.z = x[3*i+2];
504
505
506
        }
    }
    else {
507
        mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
508
        for (int i = 0; i < numDrudeParticles; ++i) {
509
510
511
512
            mm_float4& p = posq[drudeParticles[i]];
            p.x = x[3*i];
            p.y = x[3*i+1];
            p.z = x[3*i+2];
513
514
        }
    }
515
    cc.getPosq().upload(cc.getPinnedBuffer());
516
517
518
519

    // Compute the forces and energy for this configuration.

    double energy = context.calcForcesAndEnergy(true, true);
520
521
522
523
524
525
526
527
528
    long long* force = (long long*) cc.getPinnedBuffer();
    cc.getLongForceBuffer().download(force);
    double forceScale = -1.0/0x100000000;
    int paddedNumAtoms = cc.getPaddedNumAtoms();
    for (int i = 0; i < numDrudeParticles; ++i) {
        int index = drudeParticles[i];
        g[3*i] = forceScale*force[index];
        g[3*i+1] = forceScale*force[index+paddedNumAtoms];
        g[3*i+2] = forceScale*force[index+paddedNumAtoms*2];
529
530
531
532
    }
    return energy;
}

533
/**
 * Relax the Drude particle positions by L-BFGS minimization of the potential
 * energy, starting from their current positions.
 *
 * @param context    the context whose energy is minimized (passed to evaluate())
 * @param tolerance  convergence tolerance; it is divided by a position-based
 *                   normalization before being used as minimizerParams.epsilon
 */
void CommonIntegrateDrudeSCFStepKernel::minimize(ContextImpl& context, double tolerance) {
    int numDrudeParticles = drudeParticles.size();

    // With no Drude particles there is nothing to relax.  Returning early also
    // avoids dividing by zero in the normalization below and handing L-BFGS an
    // empty problem.

    if (numDrudeParticles == 0)
        return;

    // Record the initial positions.

    cc.getPosq().download(cc.getPinnedBuffer());
    if (cc.getUseDoublePrecision()) {
        mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
        for (int i = 0; i < numDrudeParticles; ++i) {
            mm_double4 p = posq[drudeParticles[i]];
            minimizerPos[3*i] = p.x;
            minimizerPos[3*i+1] = p.y;
            minimizerPos[3*i+2] = p.z;
        }
    }
    else {
        mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
        for (int i = 0; i < numDrudeParticles; ++i) {
            mm_float4 p = posq[drudeParticles[i]];
            minimizerPos[3*i] = p.x;
            minimizerPos[3*i+1] = p.y;
            minimizerPos[3*i+2] = p.z;
        }
        // Positions are stored as floats here; presumably a coarser xtol keeps the
        // line search from chasing steps smaller than float resolution.
        minimizerParams.xtol = 1e-7;
    }

    // Determine a normalization constant for scaling the tolerance: the root of
    // the mean squared position magnitude per Drude particle, or 1 if that mean
    // is below 1.

    double norm = 0.0;
    for (int i = 0; i < 3*numDrudeParticles; i++)
        norm += minimizerPos[i]*minimizerPos[i];
    norm /= numDrudeParticles;
    norm = (norm < 1 ? 1 : sqrt(norm));
    minimizerParams.epsilon = tolerance/norm;

    // Perform the minimization.  The lbfgs() status code is not checked, so a
    // minimization that stops early is not reported as an error.

    lbfgsfloatval_t fx;
    MinimizerData data(context, cc, drudeParticles);
    lbfgs(numDrudeParticles*3, minimizerPos, &fx, evaluate, NULL, &data, &minimizerParams);
}