// OpenCLKernels.h
#ifndef OPENMM_OPENCLKERNELS_H_
#define OPENMM_OPENCLKERNELS_H_

/* -------------------------------------------------------------------------- *
 *                                   OpenMM                                   *
 * -------------------------------------------------------------------------- *
 * This is part of the OpenMM molecular simulation toolkit originating from   *
 * Simbios, the NIH National Center for Physics-Based Simulation of           *
 * Biological Structures at Stanford, funded under the NIH Roadmap for        *
 * Medical Research, grant U54 GM072970. See https://simtk.org.               *
 *                                                                            *
 * Portions copyright (c) 2008-2019 Stanford University and the Authors.      *
 * Authors: Peter Eastman                                                     *
 * Contributors:                                                              *
 *                                                                            *
 * This program is free software: you can redistribute it and/or modify       *
 * it under the terms of the GNU Lesser General Public License as published   *
 * by the Free Software Foundation, either version 3 of the License, or       *
 * (at your option) any later version.                                        *
 *                                                                            *
 * This program is distributed in the hope that it will be useful,            *
 * but WITHOUT ANY WARRANTY; without even the implied warranty of             *
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the              *
 * GNU Lesser General Public License for more details.                        *
 *                                                                            *
 * You should have received a copy of the GNU Lesser General Public License   *
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.      *
 * -------------------------------------------------------------------------- */

#include "OpenCLPlatform.h"
31
32
#include "OpenCLArray.h"
#include "OpenCLContext.h"
33
#include "OpenCLFFT3D.h"
34
#include "OpenCLParameterSet.h"
35
#include "OpenCLSort.h"
36
#include "openmm/kernels.h"
37
38
39
#include "openmm/internal/CompiledExpressionSet.h"
#include "openmm/internal/CustomIntegratorUtilities.h"
#include "lepton/CompiledExpression.h"
40
#include "lepton/ExpressionProgram.h"
41
42
43
44
#include "openmm/System.h"

namespace OpenMM {

45
/**
46
47
48
 * This kernel is invoked at the beginning and end of force and energy computations.  It gives the
 * Platform a chance to clear buffers and do other initialization at the beginning, and to do any
 * necessary work at the end to determine the final results.
49
 */
50
class OpenCLCalcForcesAndEnergyKernel : public CalcForcesAndEnergyKernel {
51
public:
52
    OpenCLCalcForcesAndEnergyKernel(std::string name, const Platform& platform, OpenCLContext& cl) : CalcForcesAndEnergyKernel(name, platform), cl(cl) {
53
54
55
56
57
58
59
60
    }
    /**
     * Initialize the kernel.
     *
     * @param system     the System this kernel will be applied to
     */
    void initialize(const System& system);
    /**
61
     * This is called at the beginning of each force/energy computation, before calcForcesAndEnergy() has been called on
62
     * any ForceImpl.
63
     *
64
65
66
     * @param context       the context in which to execute this kernel
     * @param includeForce  true if forces should be computed
     * @param includeEnergy true if potential energy should be computed
67
     * @param groups        a set of bit flags for which force groups to include
68
     */
69
    void beginComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups);
70
    /**
71
     * This is called at the end of each force/energy computation, after calcForcesAndEnergy() has been called on
72
73
     * every ForceImpl.
     *
74
75
76
     * @param context       the context in which to execute this kernel
     * @param includeForce  true if forces should be computed
     * @param includeEnergy true if potential energy should be computed
77
     * @param groups        a set of bit flags for which force groups to include
78
79
     * @param valid         the method may set this to false to indicate the results are invalid and the force/energy
     *                      calculation should be repeated
80
     * @return the potential energy of the system.  This value is added to all values returned by ForceImpls'
81
     * calcForcesAndEnergy() methods.  That is, each force kernel may <i>either</i> return its contribution to the
82
83
     * energy directly, <i>or</i> add it to an internal buffer so that it will be included here.
     */
84
    double finishComputation(ContextImpl& context, bool includeForce, bool includeEnergy, int groups, bool& valid);
85
private:
86
   OpenCLContext& cl;
87
88
89
};

/**
90
91
 * This kernel provides methods for setting and retrieving various state data: time, positions,
 * velocities, and forces.
92
 */
93
class OpenCLUpdateStateDataKernel : public UpdateStateDataKernel {
94
public:
95
    OpenCLUpdateStateDataKernel(std::string name, const Platform& platform, OpenCLContext& cl) : UpdateStateDataKernel(name, platform), cl(cl) {
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
    }
    /**
     * Initialize the kernel.
     *
     * @param system     the System this kernel will be applied to
     */
    void initialize(const System& system);
    /**
     * Get the current time (in picoseconds).
     *
     * @param context    the context in which to execute this kernel
     */
    double getTime(const ContextImpl& context) const;
    /**
     * Set the current time (in picoseconds).
     *
     * @param context    the context in which to execute this kernel
     */
    void setTime(ContextImpl& context, double time);
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
    /**
     * Get the positions of all particles.
     *
     * @param positions  on exit, this contains the particle positions
     */
    void getPositions(ContextImpl& context, std::vector<Vec3>& positions);
    /**
     * Set the positions of all particles.
     *
     * @param positions  a vector containg the particle positions
     */
    void setPositions(ContextImpl& context, const std::vector<Vec3>& positions);
    /**
     * Get the velocities of all particles.
     *
     * @param velocities  on exit, this contains the particle velocities
     */
    void getVelocities(ContextImpl& context, std::vector<Vec3>& velocities);
    /**
     * Set the velocities of all particles.
     *
     * @param velocities  a vector containg the particle velocities
     */
    void setVelocities(ContextImpl& context, const std::vector<Vec3>& velocities);
    /**
     * Get the current forces on all particles.
     *
     * @param forces  on exit, this contains the forces
     */
    void getForces(ContextImpl& context, std::vector<Vec3>& forces);
145
146
147
148
149
150
    /**
     * Get the current derivatives of the energy with respect to context parameters.
     *
     * @param derivs  on exit, this contains the derivatives
     */
    void getEnergyParameterDerivatives(ContextImpl& context, std::map<std::string, double>& derivs);
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
    /**
     * Get the current periodic box vectors.
     *
     * @param a      on exit, this contains the vector defining the first edge of the periodic box
     * @param b      on exit, this contains the vector defining the second edge of the periodic box
     * @param c      on exit, this contains the vector defining the third edge of the periodic box
     */
    void getPeriodicBoxVectors(ContextImpl& context, Vec3& a, Vec3& b, Vec3& c) const;
    /**
     * Set the current periodic box vectors.
     *
     * @param a      the vector defining the first edge of the periodic box
     * @param b      the vector defining the second edge of the periodic box
     * @param c      the vector defining the third edge of the periodic box
     */
166
    void setPeriodicBoxVectors(ContextImpl& context, const Vec3& a, const Vec3& b, const Vec3& c);
Peter Eastman's avatar
Peter Eastman committed
167
168
169
170
171
172
173
174
175
176
177
178
    /**
     * Create a checkpoint recording the current state of the Context.
     * 
     * @param stream    an output stream the checkpoint data should be written to
     */
    void createCheckpoint(ContextImpl& context, std::ostream& stream);
    /**
     * Load a checkpoint that was written by createCheckpoint().
     * 
     * @param stream    an input stream the checkpoint data should be read from
     */
    void loadCheckpoint(ContextImpl& context, std::istream& stream);
179
private:
180
    OpenCLContext& cl;
181
};
182

183
184
185
186
187
/**
 * This kernel modifies the positions of particles to enforce distance constraints.
 */
class OpenCLApplyConstraintsKernel : public ApplyConstraintsKernel {
public:
188
189
    OpenCLApplyConstraintsKernel(std::string name, const Platform& platform, OpenCLContext& cl) : ApplyConstraintsKernel(name, platform),
            cl(cl), hasInitializedKernel(false) {
190
191
192
193
194
195
196
197
198
199
200
201
202
203
    }
    /**
     * Initialize the kernel.
     *
     * @param system     the System this kernel will be applied to
     */
    void initialize(const System& system);
    /**
     * Update particle positions to enforce constraints.
     *
     * @param context    the context in which to execute this kernel
     * @param tol        the distance tolerance within which constraints must be satisfied.
     */
    void apply(ContextImpl& context, double tol);
204
205
206
207
208
209
210
    /**
     * Update particle velocities to enforce constraints.
     *
     * @param context    the context in which to execute this kernel
     * @param tol        the velocity tolerance within which constraints must be satisfied.
     */
    void applyToVelocities(ContextImpl& context, double tol);
211
212
private:
    OpenCLContext& cl;
213
214
    bool hasInitializedKernel;
    cl::Kernel applyDeltasKernel;
215
216
};

217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
/**
 * Kernel responsible for recomputing where the virtual sites are located.
 */
class OpenCLVirtualSitesKernel : public VirtualSitesKernel {
public:
    OpenCLVirtualSitesKernel(std::string name, const Platform& platform, OpenCLContext& cl)
            : VirtualSitesKernel(name, platform),
              cl(cl) {
    }
    /**
     * Prepare this kernel for use with a System.
     *
     * @param system     the System the kernel is going to be applied to
     */
    void initialize(const System& system);
    /**
     * Recompute the locations of the virtual sites.
     *
     * @param context    the context in which this kernel executes
     */
    void computePositions(ContextImpl& context);
private:
    OpenCLContext& cl;  // OpenCL context used by this kernel, held by reference
};

240
241
242
243
244
/**
 * This kernel is invoked by NonbondedForce to calculate the forces acting on the system.
 */
class OpenCLCalcNonbondedForceKernel : public CalcNonbondedForceKernel {
public:
245
    OpenCLCalcNonbondedForceKernel(std::string name, const Platform& platform, OpenCLContext& cl, const System& system) : CalcNonbondedForceKernel(name, platform),
246
            hasInitializedKernel(false), cl(cl), sort(NULL), fft(NULL), dispersionFft(NULL), pmeio(NULL), usePmeQueue(false) {
247
248
249
250
251
252
253
254
255
256
    }
    ~OpenCLCalcNonbondedForceKernel();
    /**
     * Initialize the kernel.
     *
     * @param system     the System this kernel will be applied to
     * @param force      the NonbondedForce this kernel will be used for
     */
    void initialize(const System& system, const NonbondedForce& force);
    /**
257
     * Execute the kernel to calculate the forces and/or energy.
258
     *
259
260
261
     * @param context        the context in which to execute this kernel
     * @param includeForces  true if forces should be calculated
     * @param includeEnergy  true if the energy should be calculated
262
263
     * @param includeDirect  true if direct space interactions should be included
     * @param includeReciprocal  true if reciprocal space interactions should be included
264
     * @return the potential energy due to the force
265
     */
266
    double execute(ContextImpl& context, bool includeForces, bool includeEnergy, bool includeDirect, bool includeReciprocal);
267
268
269
270
271
272
273
    /**
     * Copy changed parameters over to a context.
     *
     * @param context    the context to copy parameters to
     * @param force      the NonbondedForce to copy the parameters from
     */
    void copyParametersToContext(ContextImpl& context, const NonbondedForce& force);
274
275
    /**
     * Get the parameters being used for PME.
276
     *
277
278
279
280
281
282
     * @param alpha   the separation parameter
     * @param nx      the number of grid points along the X axis
     * @param ny      the number of grid points along the Y axis
     * @param nz      the number of grid points along the Z axis
     */
    void getPMEParameters(double& alpha, int& nx, int& ny, int& nz) const;
283
284
285
286
287
288
289
290
    /**
     * Get the parameters being used for the dispersion term in LJPME.
     *
     * @param alpha   the separation parameter
     * @param nx      the number of grid points along the X axis
     * @param ny      the number of grid points along the Y axis
     * @param nz      the number of grid points along the Z axis
     */
291
    void getLJPMEParameters(double& alpha, int& nx, int& ny, int& nz) const;
292
private:
293
294
295
296
297
298
299
300
301
    class SortTrait : public OpenCLSort::SortTrait {
        int getDataSize() const {return 8;}
        int getKeySize() const {return 4;}
        const char* getDataType() const {return "int2";}
        const char* getKeyType() const {return "int";}
        const char* getMinKey() const {return "INT_MIN";}
        const char* getMaxKey() const {return "INT_MAX";}
        const char* getMaxValue() const {return "(int2) (INT_MAX, INT_MAX)";}
        const char* getSortKey() const {return "value.y";}
302
    };
303
    class ForceInfo;
304
305
306
    class PmeIO;
    class PmePreComputation;
    class PmePostComputation;
307
308
    class SyncQueuePreComputation;
    class SyncQueuePostComputation;
309
    OpenCLContext& cl;
310
    ForceInfo* info;
311
    bool hasInitializedKernel;
312
    OpenCLArray charges;
peastman's avatar
peastman committed
313
314
    OpenCLArray sigmaEpsilon;
    OpenCLArray exceptionParams;
315
316
    OpenCLArray exclusionAtoms;
    OpenCLArray exclusionParams;
317
318
319
320
321
322
323
    OpenCLArray baseParticleParams;
    OpenCLArray baseExceptionParams;
    OpenCLArray particleParamOffsets;
    OpenCLArray exceptionParamOffsets;
    OpenCLArray particleOffsetIndices;
    OpenCLArray exceptionOffsetIndices;
    OpenCLArray globalParams;
peastman's avatar
peastman committed
324
    OpenCLArray cosSinSums;
Peter Eastman's avatar
Peter Eastman committed
325
    OpenCLArray pmeGrid1;
peastman's avatar
peastman committed
326
327
328
329
330
331
332
333
334
335
336
    OpenCLArray pmeGrid2;
    OpenCLArray pmeBsplineModuliX;
    OpenCLArray pmeBsplineModuliY;
    OpenCLArray pmeBsplineModuliZ;
    OpenCLArray pmeDispersionBsplineModuliX;
    OpenCLArray pmeDispersionBsplineModuliY;
    OpenCLArray pmeDispersionBsplineModuliZ;
    OpenCLArray pmeBsplineTheta;
    OpenCLArray pmeAtomRange;
    OpenCLArray pmeAtomGridIndex;
    OpenCLArray pmeEnergyBuffer;
337
    OpenCLSort* sort;
338
339
    cl::CommandQueue pmeQueue;
    cl::Event pmeSyncEvent;
340
    OpenCLFFT3D* fft;
341
    OpenCLFFT3D* dispersionFft;
342
343
    Kernel cpuPme;
    PmeIO* pmeio;
344
    SyncQueuePostComputation* syncQueue;
Peter Eastman's avatar
Peter Eastman committed
345
    cl::Kernel computeParamsKernel, computeExclusionParamsKernel;
346
347
    cl::Kernel ewaldSumsKernel;
    cl::Kernel ewaldForcesKernel;
348
    cl::Kernel pmeAtomRangeKernel;
349
    cl::Kernel pmeDispersionAtomRangeKernel;
350
    cl::Kernel pmeZIndexKernel;
351
    cl::Kernel pmeDispersionZIndexKernel;
352
    cl::Kernel pmeUpdateBsplinesKernel;
353
    cl::Kernel pmeDispersionUpdateBsplinesKernel;
354
    cl::Kernel pmeSpreadChargeKernel;
355
    cl::Kernel pmeDispersionSpreadChargeKernel;
356
    cl::Kernel pmeFinishSpreadChargeKernel;
357
    cl::Kernel pmeDispersionFinishSpreadChargeKernel;
358
    cl::Kernel pmeConvolutionKernel;
359
    cl::Kernel pmeDispersionConvolutionKernel;
360
    cl::Kernel pmeEvalEnergyKernel;
361
    cl::Kernel pmeDispersionEvalEnergyKernel;
362
    cl::Kernel pmeInterpolateForceKernel;
363
    cl::Kernel pmeDispersionInterpolateForceKernel;
364
    std::map<std::string, std::string> pmeDefines;
365
    std::vector<std::pair<int, int> > exceptionAtoms;
366
367
    std::vector<std::string> paramNames;
    std::vector<double> paramValues;
368
    double ewaldSelfEnergy, dispersionCoefficient, alpha, dispersionAlpha;
369
    int gridSizeX, gridSizeY, gridSizeZ;
370
    int dispersionGridSizeX, dispersionGridSizeY, dispersionGridSizeZ;
Peter Eastman's avatar
Peter Eastman committed
371
    bool hasCoulomb, hasLJ, usePmeQueue, doLJPME, usePosqCharges, recomputeParams, hasOffsets;
372
    NonbondedMethod nonbondedMethod;
373
    static const int PmeOrder = 5;
374
375
};

376
/**
377
 * This kernel is invoked by CustomCVForce to calculate the forces acting on the system and the energy of the system.
378
 */
379
class OpenCLCalcCustomCVForceKernel : public CalcCustomCVForceKernel {
380
public:
381
382
    OpenCLCalcCustomCVForceKernel(std::string name, const Platform& platform, OpenCLContext& cl) : CalcCustomCVForceKernel(name, platform),
            cl(cl), hasInitializedKernels(false) {
383
384
385
386
387
    }
    /**
     * Initialize the kernel.
     *
     * @param system     the System this kernel will be applied to
388
389
     * @param force      the CustomCVForce this kernel will be used for
     * @param innerContext   the context created by the CustomCVForce for computing collective variables
390
     */
391
    void initialize(const System& system, const CustomCVForce& force, ContextImpl& innerContext);
392
    /**
393
     * Execute the kernel to calculate the forces and/or energy.
394
     *
395
     * @param context        the context in which to execute this kernel
396
     * @param innerContext   the context created by the CustomCVForce for computing collective variables
397
398
399
     * @param includeForces  true if forces should be calculated
     * @param includeEnergy  true if the energy should be calculated
     * @return the potential energy due to the force
400
     */
401
402
403
404
405
406
407
408
    double execute(ContextImpl& context, ContextImpl& innerContext, bool includeForces, bool includeEnergy);
    /**
     * Copy state information to the inner context.
     *
     * @param context        the context in which to execute this kernel
     * @param innerContext   the context created by the CustomCVForce for computing collective variables
     */
    void copyState(ContextImpl& context, ContextImpl& innerContext);
409
410
411
412
    /**
     * Copy changed parameters over to a context.
     *
     * @param context    the context to copy parameters to
413
     * @param force      the CustomCVForce to copy the parameters from
414
     */
415
    void copyParametersToContext(ContextImpl& context, const CustomCVForce& force);
416
private:
417
    class ForceInfo;
418
    class ReorderListener;
419
    OpenCLContext& cl;
420
    bool hasInitializedKernels;
421
422
423
424
425
426
427
428
    Lepton::ExpressionProgram energyExpression;
    std::vector<std::string> variableNames, paramDerivNames, globalParameterNames;
    std::vector<Lepton::ExpressionProgram> variableDerivExpressions;
    std::vector<Lepton::ExpressionProgram> paramDerivExpressions;
    std::vector<OpenCLArray> cvForces;
    OpenCLArray invAtomOrder;
    OpenCLArray innerInvAtomOrder;
    cl::Kernel copyStateKernel, copyForcesKernel, addForcesKernel;
429
430
};

431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
/*
 * This kernel is invoked by NoseHooverIntegrator to take one time step.
 */
class OpenCLIntegrateVelocityVerletStepKernel : public IntegrateVelocityVerletStepKernel {
public:
    OpenCLIntegrateVelocityVerletStepKernel(std::string name, const Platform& platform, OpenCLContext& cl) :
                                  IntegrateVelocityVerletStepKernel(name, platform), cl(cl) { }
    ~OpenCLIntegrateVelocityVerletStepKernel() {}
    /**
     * Initialize the kernel.
     * 
     * @param system     the System this kernel will be applied to
     * @param integrator the NoseHooverIntegrator this kernel will be used for
     */
    void initialize(const System& system, const NoseHooverIntegrator& integrator);
    /**
     * Execute the kernel.
     * 
     * @param context    the context in which to execute this kernel
     * @param integrator the VerletIntegrator this kernel is being used for
     * @param forcesAreValid a reference to the parent integrator's boolean for keeping
     *                       track of the validity of the current forces.
     */
    void execute(ContextImpl& context, const NoseHooverIntegrator& integrator, bool &forcesAreValid);
    /**
     * Compute the kinetic energy.
     * 
     * @param context    the context in which to execute this kernel
     * @param integrator the NoseHooverIntegrator this kernel is being used for
     */
    double computeKineticEnergy(ContextImpl& context, const NoseHooverIntegrator& integrator);
private:
    OpenCLContext& cl;
464
465
466
    float prevMaxPairDistance;
    OpenCLArray maxPairDistanceBuffer, pairListBuffer, atomListBuffer, pairTemperatureBuffer; 
    cl::Kernel kernel1, kernel2, kernel3, kernelHardWall;
467
468
469
470
471
472
473
474
475
476
477
478
479
480
};

/**
 * This kernel is invoked by NoseHooverChain at the start of each time step to adjust the thermostat
 * and update the associated particle velocities.
 */
class OpenCLNoseHooverChainKernel : public NoseHooverChainKernel {
public:
    OpenCLNoseHooverChainKernel(std::string name, const Platform& platform, OpenCLContext& cl) : NoseHooverChainKernel(name, platform), cl(cl) {
    }
    ~OpenCLNoseHooverChainKernel() {}
    /**
     * Initialize the kernel.
     */
481
    void initialize();
482
483
484
485
486
487
488
489
490
    /**
     * Execute the kernel that propagates the Nose Hoover chain and determines the velocity scale factor.
     * 
     * @param context  the context in which to execute this kernel
     * @param noseHooverChain the object describing the chain to be propagated.
     * @param kineticEnergies the {absolute, relative} kineticEnergy of the particles being thermostated by this chain.
     * @param timeStep the time step used by the integrator.
     * @return the {absolute, relative} velocity scale factor to apply to the particles associated with this heat bath.
     */
491
    std::pair<double, double> propagateChain(ContextImpl& context, const NoseHooverChain &nhc, std::pair<double, double> kineticEnergies, double timeStep);
492
493
494
495
496
497
498
    /**
     * Execute the kernal that computes the total (kinetic + potential) heat bath energy.
     *
     * @param context the context in which to execute this kernel
     * @param noseHooverChain the chain whose energy is to be determined.
     * @return the total heat bath energy.
     */
499
    double computeHeatBathEnergy(ContextImpl& context, const NoseHooverChain &nhc);
500
501
502
503
504
505
506
507
508
    /**
     * Execute the kernel that computes the kinetic energy for a subset of atoms,
     * or the relative kinetic energy of Drude particles with respect to their parent atoms
     *
     * @param context the context in which to execute this kernel
     * @param noseHooverChain the chain whose energy is to be determined.
     * @param downloadValue whether the computed value should be downloaded and returned.
     *
     */
509
    std::pair<double,double> computeMaskedKineticEnergy(ContextImpl& context, const NoseHooverChain &noseHooverChain, bool downloadValue);
510
511
512
513
514
515
516
517

    /**
     * Execute the kernel that scales the velocities of particles associated with a nose hoover chain
     *
     * @param context the context in which to execute this kernel
     * @param noseHooverChain the chain whose energy is to be determined.
     * @param scaleFactors the {absolute, relative} multiplicative factor by which velocities are scaled.
     */
518
    void scaleVelocities(ContextImpl& context, const NoseHooverChain &noseHooverChain, std::pair<double, double> scaleFactors);
519
520
521
522

private:
    int sumWorkGroupSize;
    OpenCLContext& cl;
523
    OpenCLArray energyBuffer, scaleFactorBuffer, kineticEnergyBuffer, chainMasses, chainForces, heatBathEnergy;
524
525
526
527
528
529
530
531
532
533
    std::map<int, OpenCLArray> atomlists, pairlists;
    std::map<int, cl::Kernel> propagateKernels;
    cl::Kernel reduceEnergyKernel;
    cl::Kernel computeHeatBathEnergyKernel;
    cl::Kernel computeAtomsKineticEnergyKernel;
    cl::Kernel computePairsKineticEnergyKernel;
    cl::Kernel scaleAtomsVelocitiesKernel;
    cl::Kernel scalePairsVelocitiesKernel;
};

534
535
536
537
538
539
/**
 * This kernel is invoked by MonteCarloBarostat to adjust the periodic box volume
 */
class OpenCLApplyMonteCarloBarostatKernel : public ApplyMonteCarloBarostatKernel {
public:
    OpenCLApplyMonteCarloBarostatKernel(std::string name, const Platform& platform, OpenCLContext& cl) : ApplyMonteCarloBarostatKernel(name, platform), cl(cl),
peastman's avatar
peastman committed
540
            hasInitializedKernels(false) {
541
542
543
544
545
546
547
    }
    /**
     * Initialize the kernel.
     *
     * @param system     the System this kernel will be applied to
     * @param barostat   the MonteCarloBarostat this kernel will be used for
     */
548
    void initialize(const System& system, const Force& barostat);
549
550
551
552
553
554
555
556
557
558
559
560
    /**
     * Attempt a Monte Carlo step, scaling particle positions (or cluster centers) by a specified value.
     * This version scales the x, y, and z positions independently.
     * This is called BEFORE the periodic box size is modified.  It should begin by translating each particle
     * or cluster into the first periodic box, so that coordinates will still be correct after the box size
     * is changed.
     *
     * @param context    the context in which to execute this kernel
     * @param scaleX     the scale factor by which to multiply particle x-coordinate
     * @param scaleY     the scale factor by which to multiply particle y-coordinate
     * @param scaleZ     the scale factor by which to multiply particle z-coordinate
     */
561
    void scaleCoordinates(ContextImpl& context, double scaleX, double scaleY, double scaleZ);
562
563
564
565
566
567
568
569
570
571
572
    /**
     * Reject the most recent Monte Carlo step, restoring the particle positions to where they were before
     * scaleCoordinates() was last called.
     *
     * @param context    the context in which to execute this kernel
     */
    void restoreCoordinates(ContextImpl& context);
private:
    OpenCLContext& cl;
    bool hasInitializedKernels;
    int numMolecules;
peastman's avatar
peastman committed
573
574
575
576
    OpenCLArray savedPositions;
    OpenCLArray savedForces;
    OpenCLArray moleculeAtoms;
    OpenCLArray moleculeStartIndex;
577
    cl::Kernel kernel;
578
    std::vector<int> lastAtomOrder;
579
};
580

581
582
583
} // namespace OpenMM

#endif /*OPENMM_OPENCLKERNELS_H_*/