Unverified Commit d8c67699 authored by Emilio Gallicchio's avatar Emilio Gallicchio Committed by GitHub
Browse files

Draft integration of the Alchemical Transfer Method (ATM) plugin (#4110)



* Draft integration of the Alchemical Transfer Method (ATM) plugin

* Attempt to store and retrieve forces--does not compile

* Implement addForce()/getForce() methods

* Throw exception when specifying properties without a Platform (#4130)

* Fixed DOF calculation for NoseHooverIntegrator (#4128)

* Fix variance in documentation of VerletIntegrator (#4138)

* Python API for ATMForce

* Fixed compilation error

* Minor cleanup of formatting and documentation

* Files for ATMForce test cases

* More cleanup

* Removed variable groups

* Test ATMForce with two particles

* More tests for ATMForce plus fixes

* Added missing header

* Rework interface to pass displacements as vector of parameters

* Revert "Rework interface to pass displacements as vector of parameters"

This reverts commit 5e092031f31ded1137b677588f007add1c2d6f82.

* Test with nonbonded force

* Allow energy expression to be customized

* Optional displacements at the initial state

* Fixed compilation error build C wrapper

* Address edge case of default energy expression

* Consistent naming of the variables of the displacement states

* Test of soft core function of the default energy expression

* Mark addForce() as taking ownership

* initial python test for ATMForce

* Test custom expressions

* Expanded C++ API documentation for ATMForce

* Energy parameter derivatives

* Serialization for ATMForce

* Documentation, cleanup, and fixes

* Fixed typos

* getPerturbationEnergy() computes energy

* Another test case

* Minor edits

---------
Co-authored-by: default avatarPeter Eastman <peastman@stanford.edu>
Co-authored-by: default avatarMichael Plainer <plainer@ymail.com>
parent 889baef6
...@@ -64,6 +64,7 @@ be combined in arbitrary ways. ...@@ -64,6 +64,7 @@ be combined in arbitrary ways.
:maxdepth: 2 :maxdepth: 2
generated/AndersenThermostat generated/AndersenThermostat
generated/ATMForce
generated/CMMotionRemover generated/CMMotionRemover
generated/MonteCarloAnisotropicBarostat generated/MonteCarloAnisotropicBarostat
generated/MonteCarloBarostat generated/MonteCarloBarostat
......
...@@ -29,6 +29,18 @@ ...@@ -29,6 +29,18 @@
type = {Journal Article} type = {Journal Article}
} }
@article{Azimi2022,
author = {Azimi, Solmaz and Khuttan, Sheenam and Wu, Joe Z. and Pal, Rajat K. and Gallicchio, Emilio},
title = {Relative Binding Free Energy Calculations for Ligands with Diverse Scaffolds with the Alchemical Transfer Method},
journal = {Journal of Chemical Information and Modeling},
volume = {62},
number = {2},
pages = {309-323},
year = {2022},
type = {Journal Article},
doi = {10.1021/acs.jcim.1c01129}
}
@article{Barducci2008, @article{Barducci2008,
title = {Well-Tempered Metadynamics: A Smoothly Converging and Tunable Free-Energy Method}, title = {Well-Tempered Metadynamics: A Smoothly Converging and Tunable Free-Energy Method},
author = {Barducci, Alessandro and Bussi, Giovanni and Parrinello, Michele}, author = {Barducci, Alessandro and Bussi, Giovanni and Parrinello, Michele},
......
...@@ -413,6 +413,14 @@ be used as a collective variable. The energy is then computed as ...@@ -413,6 +413,14 @@ be used as a collective variable. The energy is then computed as
where *f*\ (...) is a user supplied mathematical expression of the collective where *f*\ (...) is a user supplied mathematical expression of the collective
variables. It also may depend on user defined global parameters. variables. It also may depend on user defined global parameters.
ATMForce
********
ATMForce implements the Alchemical Transfer Method for free energy calculations.\ :cite:`Azimi2022`
It contains one or more :code:`Force` objects whose energy is evaluated twice,
before and after displacing some particles to new positions. The final energy
is determined by a user supplied mathematical function of the two energies. See
the API documentation and the publication for more details.
.. _writing-custom-expressions: .. _writing-custom-expressions:
......
...@@ -17,7 +17,7 @@ SET(OpenMM_CWRAPPER "OpenMMCWrapper") ...@@ -17,7 +17,7 @@ SET(OpenMM_CWRAPPER "OpenMMCWrapper")
SET(OpenMM_FWRAPPER "OpenMMFortranWrapper") SET(OpenMM_FWRAPPER "OpenMMFortranWrapper")
SET(OpenMM_FMODULE "OpenMMFortranModule") SET(OpenMM_FMODULE "OpenMMFortranModule")
SET(CPP_EXAMPLES HelloArgon HelloSodiumChloride HelloEthane HelloWaterBox) SET(CPP_EXAMPLES HelloArgon HelloSodiumChloride HelloEthane HelloWaterBox )
SET(C_EXAMPLES HelloArgonInC HelloSodiumChlorideInC) SET(C_EXAMPLES HelloArgonInC HelloSodiumChlorideInC)
SET(F_EXAMPLES HelloArgonInFortran HelloSodiumChlorideInFortran) SET(F_EXAMPLES HelloArgonInFortran HelloSodiumChlorideInFortran)
......
...@@ -66,6 +66,7 @@ ...@@ -66,6 +66,7 @@
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
#include "openmm/NoseHooverIntegrator.h" #include "openmm/NoseHooverIntegrator.h"
#include "openmm/NoseHooverChain.h" #include "openmm/NoseHooverChain.h"
#include "openmm/ATMForce.h"
#include <iosfwd> #include <iosfwd>
#include <set> #include <set>
#include <string> #include <string>
...@@ -1637,6 +1638,55 @@ public: ...@@ -1637,6 +1638,55 @@ public:
virtual void getPMEParameters(double& alpha, int& nx, int& ny, int& nz) const = 0; virtual void getPMEParameters(double& alpha, int& nx, int& ny, int& nz) const = 0;
}; };
/**
* This kernel is invoked by ATMForce to calculate the forces acting on the system and the energy of the system.
*/
class CalcATMForceKernel : public KernelImpl {
public:
static std::string Name() {
return "CalcATMForce";
}
CalcATMForceKernel(std::string name, const Platform& platform) : KernelImpl(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
* @param force the ATMForce this kernel will be used for
*/
virtual void initialize(const System& system, const ATMForce& force) = 0;
/**
* Scale the forces from the inner contexts and apply them to the main context.
*
* @param context the context in which to execute this kernel
* @param innerContext0 the first inner context
* @param innerContext1 the second inner context
* @param dEdu0 the derivative of the final energy with respect to the first inner context's energy
* @param dEdu1 the derivative of the final energy with respect to the second inner context's energy
* @param energyParamDerivs derivatives of the final energy with respect to global parameters
*/
virtual void applyForces(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1,
double dEdu0, double dEdu1, const std::map<std::string, double>& energyParamDerivs) = 0;
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the ATMForce to copy the parameters from
*/
virtual void copyParametersToContext(ContextImpl& context, const ATMForce& force) = 0;
/**
* Copy state information to the inner contexts.
*
* @param context the context in which to execute this kernel
* @param innerContext0 the first context created by the ATMForce for computing displaced energy
* @param innerContext1 the second context created by the ATMForce for computing displaced energy
*/
virtual void copyState(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1) = 0;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_KERNELS_H_*/ #endif /*OPENMM_KERNELS_H_*/
...@@ -81,5 +81,6 @@ ...@@ -81,5 +81,6 @@
#include "openmm/VirtualSite.h" #include "openmm/VirtualSite.h"
#include "openmm/Platform.h" #include "openmm/Platform.h"
#include "openmm/serialization/XmlSerializer.h" #include "openmm/serialization/XmlSerializer.h"
#include "openmm/ATMForce.h"
#endif /*OPENMM_H_*/ #endif /*OPENMM_H_*/
#ifndef OPENMM_ATMFORCE_H_
#define OPENMM_ATMFORCE_H_
/* -------------------------------------------------------------------------- *
* OpenMM's Alchemical Transfer Force *
* -------------------------------------------------------------------------- *
* This is a Force of the OpenMM molecular simulation toolkit *
* that implements the Alchemical Transfer Potential *
* for absolute and relative binding free energy estimation *
* (https://doi.org/10.1021/acs.jcim.1c01129). The code is derived from the *
* ATMMetaForce plugin *
* https://github.com/Gallicchio-Lab/openmm-atmmetaforce-plugin *
* with support from the National Science Foundation CAREER 1750511 *
* *
* Portions copyright (c) 2021-2023 by the Authors *
* Authors: Emilio Gallicchio *
* Contributors: Peter Eastman *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "Force.h"
#include "internal/AssertionUtilities.h"
#include <openmm/Vec3.h>
#include <vector>
#include <string>
#include "internal/windowsExport.h"
namespace OpenMM {
/**
*
* The ATMForce class implements the Alchemical Transfer Method (ATM) for OpenMM.
* ATM is used to compute the binding free energies of molecular complexes and of other equilibrium processes.
* ATM and its implementation are described in the open access article:
*
* Solmaz Azimi, Sheenam Khuttan, Joe Z. Wu, Rajat K. Pal, and Emilio Gallicchio.
* Relative Binding Free Energy Calculations for Ligands with Diverse Scaffolds with the Alchemical Transfer Method.
* J. Chem. Inf. Model. 62, 309 (2022)
* https://doi.org/10.1021/acs.jcim.1c01129
*
* Refer to the publication above for a detailed description of the ATM method and the parameters used in this API
* and please cite it to support our work if you use this software in your research.
*
* The ATMForce implements an arbitrary potential energy function that depends on the potential
* energies (u0 and u1) of the system before and after a set of atoms are displaced by a specified amount.
* For example, you might displace a molecule from the solvent bulk to a receptor binding site to simulate
* a binding process. The potential energy function typically also depends on one or more parameters that
* are dialed to implement alchemical transformations.
*
* To use this class, create an ATMForce object, passing an algebraic expression to the
* constructor that defines the potential energy. This expression can be any combination
* of the variables u0 and u1. Then call addGlobalParameter() to define the parameters on which the potential energy expression depends.
* The values of global parameters may be modified during a simulation by calling Context::setParameter().
* Next, call addForce() to add Force objects that define the terms of the potential energy function
* that change upon displacement. Finally, call addParticle() to specify the displacement applied to
* each particle. Displacements can be changed by calling setParticleParameters(). As any per-particle parameters,
* changes in displacements take effect only after calling updateParametersInContext().
*
* As an example, the following code creates an ATMForce based on the change in energy of
* two particles when the second particle is displaced by 1 nm in the x direction.
* The energy change is dialed using an alchemical parameter Lambda, which in this case is set to 1/2:
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* ATMForce *atmforce = new ATMForce("u0 + Lambda*(u1 - u0)");
* atm->addGlobalParameter("Lambda", 0.5);
* atm->addParticle(Vec3(0, 0, 0));
* atm->addParticle(Vec3(1, 0, 0));
* CustomBondForce* force = new CustomBondForce("0.5*r^2");
* atm->addForce(force);
* \endverbatim
*
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, atan2, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta,
* select. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
* select(x,y,z) = z if x = 0, y otherwise.
*
* If instead of the energy expression the ATMForce constructor specifies the values of a series of parameters,
* the default energy expression is used:
*
* \verbatim embed:rst:leading-asterisk
* .. code-block::
*
* select(step(Direction), u0, u1) + ((Lambda2-Lambda1)/Alpha)*log(1+exp(-Alpha*(usc-Uh))) + Lambda2*usc + W0;
* usc = select(step(u-Ubcore), (Umax-Ubcore)*fsc+Ubcore, u), u);
* fsc = (z^Acore-1)/(z^Acore+1);
* z = 1 + 2*(y/Acore) + 2*(y/Acore)^2;
* y = (u-Ubcore)/(Umax-Ubcore);
* u = select(step(Direction), 1, -1)*(u1-u0)
* \endverbatim
*
* which is the same as the soft-core softplus alchemical potential energy function in the Azimi et al. paper above.
*
* The ATMForce is then added to the System as any other Force
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* system.addForce(atmforce);
* \endverbatim
*
* after which it will be used for energy/force evaluations for molecular dynamics and energy optimization.
* You can call getPerturbationEnergy() to query the values of u0 and u1, which are needed for computing
* free energies.
*
* In most cases, particles are only displaced in one of the two states evaluated by this force. It computes the
* change in energy between the current particle coordinates (as stored in the Context) and the displaced coordinates.
* In some cases, it is useful to apply displacements to both states. You can do this by providing two displacement
* vectors to addParticle():
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* atm->addParticle(Vec3(1, 0, 0), Vec3(-1, 0, 0));
* \endverbatim
*
* In this case, u1 will be computed after displacing the particle in the positive x direction, and
* u0 will be computed after displacing it in the negative x direction.
*
* This class also has the ability to compute derivatives of the potential energy with respect to global parameters.
* Call addEnergyParameterDerivative() to request that the derivative with respect to a particular parameter be
* computed. You can then query its value in a Context by calling getState() on it.
*/
class OPENMM_EXPORT ATMForce : public OpenMM::Force {
public:
/**
* Create an ATMForce object.
*
* @param energy an algebraic expression giving the energy of the system as a function
* of u0 and u1, the energies before and after displacement
*/
explicit ATMForce(const std::string& energy);
/**
* Create an ATMForce object with the default softplus energy expression. The values passed to
* this constructor are the default values of the global parameters for newly created Contexts.
* Their values can be changed by calling setParameter() on the Context using the parameter
* names defined by the Lambda1(), Lambda2(), etc. methods below.
*
* @param lambda1 the default value of the Lambda1 parameter (dimensionless). This should be
* a number between 0 and 1.
* @param lambda2 the default value of the Lambda2 parameter (dimensionless). This should be
* a number between 0 and 1.
* @param alpha the default value of the Alpha parameter (kJ/mol)^-1
* @param uh the default value of the Uh parameter (kJ/mol)
* @param w0 the default value of the W0 parameter (kJ/mol)
* @param umax the default value of the Umax parameter (kJ/mol)
* @param ubcore the default value of the Ubcore parameter (kJ/mol)
* @param acore the default value of the Acore parameter dimensionless)
* @param direction the default value of the Direction parameter (dimensionless). This should be
* either 1 for the forward transfer, or -1 for the backward transfer.
*/
ATMForce(double lambda1, double lambda2, double alpha, double uh, double w0, double umax, double ubcore, double acore, double direction);
~ATMForce();
/**
* Get the number of particles managed by ATMForce.
*
* This should be the same number of particles as the System
*/
int getNumParticles() const {
return particles.size();
}
/**
* Get the number of Forces included in the ATMForce.
*/
int getNumForces() const {
return forces.size();
}
/**
* Get the number of global parameters that the interaction depends on.
*/
int getNumGlobalParameters() const {
return globalParameters.size();
}
/**
* Get the number of global parameters with respect to which the derivative of the energy
* should be computed.
*/
int getNumEnergyParameterDerivatives() const {
return energyParameterDerivatives.size();
}
/**
* Get the algebraic expression that gives the energy of the system
*/
const std::string& getEnergyFunction() const;
/**
* Set the algebraic expression that gives the energy of the system
*/
void setEnergyFunction(const std::string& energy);
/**
* Add a Force whose energy will be computed by the ATMForce.
*
* @param force the Force to the be added, which should have been created on the heap with the
* "new" operator. The ATMForce takes over ownership of it, and deletes the Force when the
* ATMForce itself is deleted.
* @return The index within ATMForce of the force that was added
*/
int addForce(Force* force);
/**
* return the force from index
*/
Force& getForce(int index) const;
/**
* Add a particle to the force.
*
* All of the particles in the System must be added to the ATMForce in the same order
* as they appear in the System.
*
* @param displacement1 the displacement of the particle for the target state in nm
* @param displacement0 the displacement of the particle for the initial state in nm
* @return the index of the particle that was added
*/
int addParticle(const Vec3& displacement1, const Vec3& displacement0=Vec3());
/**
* Get the parameters for a particle
*
* @param index the index in the force for the particle for which to get parameters
* @param displacement1 the displacement of the particle for the target state in nm
* @param displacement0 the displacement of the particle for the initial state in nm
*/
void getParticleParameters(int index, Vec3& displacement1, Vec3& displacement0) const;
/**
* Set the parameters for a particle
*
* @param index the index in the force of the particle for which to set parameters
* @param displacement1 the displacement of the particle for the target state in nm
* @param displacement0 the displacement of the particle for the initial state in nm
*/
void setParticleParameters(int index, const Vec3& displacement1, const Vec3& displacement0=Vec3());
/**
* Add a new global parameter that the interaction may depend on. The default value provided to
* this method is the initial value of the parameter in newly created Contexts. You can change
* the value at any time by calling setParameter() on the Context.
*
* @param name the name of the parameter
* @param defaultValue the default value of the parameter
* @return the index of the parameter that was added
*/
int addGlobalParameter(const std::string& name, double defaultValue);
/**
* Get the name of a global parameter.
*
* @param index the index of the parameter for which to get the name
* @return the parameter name
*/
const std::string& getGlobalParameterName(int index) const;
/**
* Set the name of a global parameter.
*
* @param index the index of the parameter for which to set the name
* @param name the name of the parameter
*/
void setGlobalParameterName(int index, const std::string& name);
/**
* Get the default value of a global parameter.
*
* @param index the index of the parameter for which to get the default value
* @return the parameter default value
*/
double getGlobalParameterDefaultValue(int index) const;
/**
* Set the default value of a global parameter.
*
* @param index the index of the parameter for which to set the default value
* @param defaultValue the default value of the parameter
*/
void setGlobalParameterDefaultValue(int index, double defaultValue);
/**
* Request that this Force compute the derivative of its energy with respect to a global parameter.
* The parameter must have already been added with addGlobalParameter().
*
* @param name the name of the parameter
*/
void addEnergyParameterDerivative(const std::string& name);
/**
* Get the name of a global parameter with respect to which this Force should compute the
* derivative of the energy.
*
* @param index the index of the parameter derivative, between 0 and getNumEnergyParameterDerivatives()
* @return the parameter name
*/
const std::string& getEnergyParameterDerivativeName(int index) const;
/**
* Update the per-particle parameters in a Context to match those stored in this Force object. This method
* should be called after updating parameters with setParticleParameters() to copy them over to the Context.
* The only information this method updates is the values of per-particle parameters. The number of particles
* cannot be changed.
*/
void updateParametersInContext(Context& context);
/**
* Returns whether or not this force makes use of periodic boundary conditions.
*/
bool usesPeriodicBoundaryConditions() const;
/**
* Returns the current perturbation energy.
*
* @param context the Context for which to return the energy
* @param u1 on exit, the energy of the displaced state
* @param u0 on exit, the energy of the non-displaced state
* @param energy on exit, the value of this force's energy function
*/
void getPerturbationEnergy(Context& context, double& u1, double& u0, double& energy);
/**
* Returns the name of the global parameter corresponding to lambda1. The value assigned to this
* parameter should be a number between 0 and 1.
*/
static const std::string& Lambda1() {
static const std::string key = "Lambda1";
return key;
}
/**
* Returns the name of the global parameter corresponding to lambda2. The value assigned to this
* parameter should be a number between 0 and 1.
*/
static const std::string& Lambda2() {
static const std::string key = "Lambda2";
return key;
}
/**
* Returns the name of the global parameter corresponding to alpha. The value assigned to this
* parameter should be in units of (kJ/mol)^-1.
*/
static const std::string& Alpha() {
static const std::string key = "Alpha";
return key;
}
/**
* Returns the name of the global parameter corresponding to uh. The value assigned to this
* parameter should be in units of (kJ/mol).
*/
static const std::string& Uh() {
static const std::string key = "Uh";
return key;
}
/**
* Returns the name of the global parameter corresponding to w0. The value assigned to this
* parameter should be in units of (kJ/mol).
*/
static const std::string& W0() {
static const std::string key = "W0";
return key;
}
/**
* Returns the name of the global parameter corresponding to umax. The value assigned to this
* parameter should be in units of (kJ/mol).
*/
static const std::string& Umax() {
static const std::string key = "Umax";
return key;
}
/**
* Returns the name of the global parameter corresponding to ubcore. The value assigned to this
* parameter should be in units of (kJ/mol).
*/
static const std::string& Ubcore() {
static const std::string key = "Ubcore";
return key;
}
/**
* Returns the name of the global parameter corresponding to acore.
*/
static const std::string& Acore() {
static const std::string key = "Acore";
return key;
}
/**
* Returns the name of the global parameter corresponding to direction. The value assigned to
* this parameter should be either 1 for the forward transfer, or -1 for the backward transfer.
*/
static const std::string& Direction() {
static const std::string key = "Direction";
return key;
}
protected:
ForceImpl* createImpl() const;
private:
class ParticleInfo;
class GlobalParameterInfo;
std::string energyExpression;
std::vector<GlobalParameterInfo> globalParameters;
std::vector<Force *> forces;
std::vector<ParticleInfo> particles;
std::vector<int> energyParameterDerivatives;
};
/**
* This is an internal class used to record information about a particle.
* @private
*/
class ATMForce::ParticleInfo {
public:
int index;
Vec3 displacement1, displacement0;
ParticleInfo() : index(-1) {
}
ParticleInfo(int index) : index(index) {
}
ParticleInfo(int index, Vec3 displacement1, Vec3 displacement0) :
index(index), displacement1(displacement1), displacement0(displacement0) {
}
};
/**
* This is an internal class used to record information about a global parameter.
* @private
*/
class ATMForce::GlobalParameterInfo {
public:
std::string name;
double defaultValue;
GlobalParameterInfo() {
}
GlobalParameterInfo(const std::string& name, double defaultValue) : name(name), defaultValue(defaultValue) {
}
};
} // namespace OpenMM
#endif /*OPENMM_ATMFORCE_H_*/
#ifndef OPENMM_ATMFORCEFORCEIMPL_H_
#define OPENMM_ATMFORCEFORCEIMPL_H_
#include "openmm/ATMForce.h"
#include "openmm/internal/ForceImpl.h"
#include "openmm/Kernel.h"
#include "openmm/System.h"
#include "openmm/VerletIntegrator.h"
#include "openmm/internal/windowsExport.h"
#include "lepton/CompiledExpression.h"
#include <utility>
#include <set>
#include <string>
namespace OpenMM {
/**
* This is the internal implementation of ATMForce.
*/
class OPENMM_EXPORT ATMForceImpl : public ForceImpl {
public:
ATMForceImpl(const ATMForce& owner);
~ATMForceImpl();
void initialize(ContextImpl& context);
const ATMForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
std::map<std::string, double> getDefaultParameters();
std::vector<std::string> getKernelNames();
std::vector<std::pair<int, int> > getBondedParticles() const;
void updateParametersInContext(ContextImpl& context);
void getPerturbationEnergy(ContextImpl& context, double& u1, double& u0, double& energy);
private:
const ATMForce& owner;
Kernel kernel;
System innerSystem0, innerSystem1;
VerletIntegrator innerIntegrator0, innerIntegrator1;
Context *innerContext0, *innerContext1;
Lepton::CompiledExpression energyExpression, u0DerivExpression, u1DerivExpression;
double state0Energy, state1Energy, combinedEnergy;
std::vector<std::string> globalParameterNames, paramDerivNames;
std::vector<double> globalValues;
std::vector<Lepton::CompiledExpression> paramDerivExpressions;
void copySystem(ContextImpl& context, const System& system, System& innerSystem);
};
} // namespace OpenMM
#endif /*OPENMM_ATMFORCEIMPL_H_*/
/* -------------------------------------------------------------------------- *
* OpenMM's Alchemical Transfer Force *
* -------------------------------------------------------------------------- *
* This is a Force of the OpenMM molecular simulation toolkit *
* that implements the Alchemical Transfer Potential *
* for absolute and relative binding free energy estimation *
* (https://doi.org/10.1021/acs.jcim.1c01129). The code is derived from the *
* ATMMetaForce plugin *
* https://github.com/Gallicchio-Lab/openmm-atmmetaforce-plugin *
* with support from the National Science Foundation CAREER 1750511 *
* *
* Portions copyright (c) 2021-2023 by the Authors *
* Authors: Emilio Gallicchio *
* Contributors: Peter Eastman *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/ATMForce.h"
#include "openmm/Force.h"
#include "openmm/serialization/XmlSerializer.h"
#include "openmm/internal/ATMForceImpl.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/AssertionUtilities.h"
#include <iostream>
#include <cmath>
#include <string>
using namespace OpenMM;
using namespace std;
ATMForce::ATMForce(const string& energy) : energyExpression(energy) {
}
ATMForce::ATMForce(double lambda1, double lambda2, double alpha, double uh, double w0, double umax, double ubcore, double acore, double direction) {
if (alpha < 0)
throw OpenMMException("ATMForce: alpha cannot be negative");
if (lambda1 != lambda2 && alpha == 0)
throw OpenMMException("ATMForce: alpha must be positive when lambda1 and lambda2 are different");
if (umax < ubcore)
throw OpenMMException("ATMForce: umax cannot be less than ubcore");
if (acore < 0)
throw OpenMMException("ATMForce: acore cannot be negative");
if (direction != 1.0 && direction != -1.0)
throw OpenMMException("ATMForce: direction must be either 1 or -1");
string referencePotExpression = "select(step(Direction), u0, u1) + ";
string alchemicalPotExpression = "select(Lambda2-Lambda1 , ((Lambda2-Lambda1)/Alpha)*log(1+exp(-Alpha*(usc-Uh))) + Lambda2*usc + W0, Lambda2*usc + W0);";
string softCoreExpression = "usc = select(Acore, select(step(u-Ubcore), (Umax-Ubcore)*fsc+Ubcore, u), u);"
"fsc = (z^Acore-1)/(z^Acore+1);"
"z = 1 + 2*(y/Acore) + 2*(y/Acore)^2;"
"y = (u-Ubcore)/(Umax-Ubcore);"
"u = select(step(Direction), 1, -1)*(u1-u0)";
setEnergyFunction(referencePotExpression + alchemicalPotExpression + softCoreExpression);
addGlobalParameter(Lambda1(), lambda1);
addGlobalParameter(Lambda2(), lambda2);
addGlobalParameter(Alpha(), alpha);
addGlobalParameter(Uh(), uh);
addGlobalParameter(W0(), w0);
addGlobalParameter(Umax(), umax);
addGlobalParameter(Ubcore(), ubcore);
addGlobalParameter(Acore(), acore);
addGlobalParameter(Direction(), direction);
}
ATMForce::~ATMForce() {
for (Force* force : forces)
delete force;
}
const string& ATMForce::getEnergyFunction() const {
return energyExpression;
}
void ATMForce::setEnergyFunction(const std::string& energy) {
energyExpression = energy;
}
int ATMForce::addParticle(const Vec3& displacement1, const Vec3& displacement0) {
particles.push_back(ParticleInfo(particles.size(), displacement1, displacement0));
return particles.size()-1;
}
void ATMForce::getParticleParameters(int index, Vec3& displacement1, Vec3& displacement0) const {
ASSERT_VALID_INDEX(index, particles);
displacement1 = particles[index].displacement1;
displacement0 = particles[index].displacement0;
}
void ATMForce::setParticleParameters(int index, const Vec3& displacement1, const Vec3& displacement0) {
ASSERT_VALID_INDEX(index, particles);
particles[index].displacement1 = displacement1;
particles[index].displacement0 = displacement0;
}
int ATMForce::addForce(Force* force) {
forces.push_back(force);
return forces.size()-1;
}
Force& ATMForce::getForce(int index) const {
ASSERT_VALID_INDEX(index, forces);
return *forces[index];
}
int ATMForce::addGlobalParameter(const string& name, double defaultValue) {
globalParameters.push_back(GlobalParameterInfo(name, defaultValue));
return globalParameters.size()-1;
}
const string& ATMForce::getGlobalParameterName(int index) const {
ASSERT_VALID_INDEX(index, globalParameters);
return globalParameters[index].name;
}
void ATMForce::setGlobalParameterName(int index, const string& name) {
ASSERT_VALID_INDEX(index, globalParameters);
globalParameters[index].name = name;
}
double ATMForce::getGlobalParameterDefaultValue(int index) const {
ASSERT_VALID_INDEX(index, globalParameters);
return globalParameters[index].defaultValue;
}
void ATMForce::setGlobalParameterDefaultValue(int index, double defaultValue) {
ASSERT_VALID_INDEX(index, globalParameters);
globalParameters[index].defaultValue = defaultValue;
}
void ATMForce::addEnergyParameterDerivative(const string& name) {
for (int i = 0; i < globalParameters.size(); i++)
if (name == globalParameters[i].name) {
energyParameterDerivatives.push_back(i);
return;
}
throw OpenMMException(string("addEnergyParameterDerivative: Unknown global parameter '"+name+"'"));
}
const string& ATMForce::getEnergyParameterDerivativeName(int index) const {
ASSERT_VALID_INDEX(index, energyParameterDerivatives);
return globalParameters[energyParameterDerivatives[index]].name;
}
ForceImpl* ATMForce::createImpl() const {
return new ATMForceImpl(*this);
}
void ATMForce::updateParametersInContext(OpenMM::Context& context) {
dynamic_cast<ATMForceImpl&>(getImplInContext(context)).updateParametersInContext(getContextImpl(context));
}
bool ATMForce::usesPeriodicBoundaryConditions() const {
for (auto& force : forces)
if (force->usesPeriodicBoundaryConditions())
return true;
return false;
}
void ATMForce::getPerturbationEnergy(OpenMM::Context& context, double& u0, double& u1, double& energy) {
dynamic_cast<ATMForceImpl&>(getImplInContext(context)).getPerturbationEnergy(getContextImpl(context), u0, u1, energy);
}
/* -------------------------------------------------------------------------- *
* OpenMM's Alchemical Transfer Force *
* -------------------------------------------------------------------------- *
* This is a Force of the OpenMM molecular simulation toolkit *
* that implements the Alchemical Transfer Potential *
* for absolute and relative binding free energy estimation *
* (https://doi.org/10.1021/acs.jcim.1c01129). The code is derived from the *
* ATMMetaForce plugin *
* https://github.com/Gallicchio-Lab/openmm-atmmetaforce-plugin *
* with support from the National Science Foundation CAREER 1750511 *
* *
* Portions copyright (c) 2021-2023 by the Authors *
* Authors: Emilio Gallicchio *
* Contributors: Peter Eastman *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#ifdef WIN32
#define _USE_MATH_DEFINES // Needed to get M_PI
#endif
#include "openmm/internal/ATMForceImpl.h"
#include "openmm/NonbondedForce.h"
#include "openmm/kernels.h"
#include "openmm/serialization/XmlSerializer.h"
#include "openmm/Vec3.h"
#include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h"
#include "lepton/ParsedExpression.h"
#include "lepton/Parser.h"
#include <cmath>
#include <map>
#include <set>
#include <sstream>
#include <iostream>
using namespace OpenMM;
using namespace std;
ATMForceImpl::ATMForceImpl(const ATMForce& owner) : owner(owner), innerIntegrator0(1.0), innerIntegrator1(1.0) {
Lepton::ParsedExpression expr = Lepton::Parser::parse(owner.getEnergyFunction()).optimize();
energyExpression = expr.createCompiledExpression();
u0DerivExpression = expr.differentiate("u0").createCompiledExpression();
u1DerivExpression = expr.differentiate("u1").createCompiledExpression();
for (int i = 0; i < owner.getNumGlobalParameters(); i++)
globalParameterNames.push_back(owner.getGlobalParameterName(i));
globalValues.resize(globalParameterNames.size());
map<string, double*> variableLocations;
variableLocations["u0"] = &state0Energy;
variableLocations["u1"] = &state1Energy;
for (int i = 0; i < globalParameterNames.size(); i++)
variableLocations[globalParameterNames[i]] = &globalValues[i];
energyExpression.setVariableLocations(variableLocations);
u0DerivExpression.setVariableLocations(variableLocations);
u1DerivExpression.setVariableLocations(variableLocations);
for (int i = 0; i < owner.getNumEnergyParameterDerivatives(); i++) {
string name = owner.getEnergyParameterDerivativeName(i);
paramDerivNames.push_back(name);
paramDerivExpressions.push_back(expr.differentiate(name).createCompiledExpression());
paramDerivExpressions[i].setVariableLocations(variableLocations);
}
}
ATMForceImpl::~ATMForceImpl() {
}
void ATMForceImpl::copySystem(ContextImpl& context, const OpenMM::System& system, OpenMM::System& innerSystem) {
//copy particles
for (int i = 0; i < system.getNumParticles(); i++)
innerSystem.addParticle(system.getParticleMass(i));
//copy periodic box dimensions
Vec3 a, b, c;
system.getDefaultPeriodicBoxVectors(a, b, c);
innerSystem.setDefaultPeriodicBoxVectors(a, b, c);
// Add forces to the inner contexts
for (int i = 0; i < owner.getNumForces(); i++) {
const Force &force = owner.getForce(i);
innerSystem.addForce(XmlSerializer::clone<Force>(force));
}
}
void ATMForceImpl::initialize(ContextImpl& context) {
const OpenMM::System& system = context.getSystem();
copySystem(context, system, innerSystem0);
copySystem(context, system, innerSystem1);
// Create the inner context.
innerContext0 = context.createLinkedContext(innerSystem0, innerIntegrator0);
innerContext1 = context.createLinkedContext(innerSystem1, innerIntegrator1);
vector<Vec3> positions(system.getNumParticles(), Vec3());
innerContext0->setPositions(positions);
innerContext1->setPositions(positions);
// Create the kernel.
kernel = context.getPlatform().createKernel(CalcATMForceKernel::Name(), context);
kernel.getAs<CalcATMForceKernel>().initialize(context.getSystem(), owner);
}
double ATMForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups & (1 << owner.getForceGroup())) == 0)
return 0.0;
ContextImpl& innerContextImpl0 = getContextImpl(*innerContext0);
ContextImpl& innerContextImpl1 = getContextImpl(*innerContext1);
// Copy the coordinates etc. from the context to the inner contexts
kernel.getAs<CalcATMForceKernel>().copyState(context, innerContextImpl0, innerContextImpl1);
// Evaluate energy and forces for the two systems
state0Energy = innerContextImpl0.calcForcesAndEnergy(includeForces, true);
state1Energy = innerContextImpl1.calcForcesAndEnergy(includeForces, true);
// Compute the alchemical energy and forces.
for (int i = 0; i < globalParameterNames.size(); i++)
globalValues[i] = context.getParameter(globalParameterNames[i]);
combinedEnergy = energyExpression.evaluate();
if (includeForces) {
double dEdu0 = u0DerivExpression.evaluate();
double dEdu1 = u1DerivExpression.evaluate();
map<string, double> energyParamDerivs;
for (int i = 0; i < paramDerivExpressions.size(); i++)
energyParamDerivs[paramDerivNames[i]] += paramDerivExpressions[i].evaluate();
kernel.getAs<CalcATMForceKernel>().applyForces(context, innerContextImpl0, innerContextImpl1, dEdu0, dEdu1, energyParamDerivs);
}
return (includeEnergy ? combinedEnergy : 0.0);
}
std::map<std::string, double> ATMForceImpl::getDefaultParameters() {
map<string, double> parameters;
parameters.insert(innerContext0->getParameters().begin(), innerContext0->getParameters().end());
for (int i = 0; i < owner.getNumGlobalParameters(); i++)
parameters[owner.getGlobalParameterName(i)] = owner.getGlobalParameterDefaultValue(i);
return parameters;
}
std::vector<std::string> ATMForceImpl::getKernelNames() {
std::vector<std::string> names;
names.push_back(CalcATMForceKernel::Name());
return names;
}
vector<pair<int, int> > ATMForceImpl::getBondedParticles() const {
vector<pair<int, int> > bonds;
const ContextImpl& innerContextImpl = getContextImpl(*innerContext0);
for (auto& impl : innerContextImpl.getForceImpls()) {
for (auto& bond : impl->getBondedParticles())
bonds.push_back(bond);
}
return bonds;
}
void ATMForceImpl::updateParametersInContext(ContextImpl& context) {
kernel.getAs<CalcATMForceKernel>().copyParametersToContext(context, owner);
}
void ATMForceImpl::getPerturbationEnergy(ContextImpl& context, double& u1, double& u0, double& energy) {
calcForcesAndEnergy(context, false, true, -1);
u0 = state0Energy;
u1 = state1Energy;
energy = combinedEnergy;
}
...@@ -1586,6 +1586,75 @@ private: ...@@ -1586,6 +1586,75 @@ private:
std::vector<mm_int4> lastPosCellOffsets; std::vector<mm_int4> lastPosCellOffsets;
}; };
/**
* This kernel is invoked by ATMForce to calculate the forces acting on the system and the energy of the system.
*/
class CommonCalcATMForceKernel : public CalcATMForceKernel {
public:
CommonCalcATMForceKernel(std::string name, const Platform& platform, ComputeContext& cc): CalcATMForceKernel(name, platform), hasInitializedKernel(false), cc(cc) {
}
~CommonCalcATMForceKernel();
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
* @param force the ATMForce this kernel will be used for
*/
void initialize(const System& system, const ATMForce& force);
/**
* Scale the forces from the inner contexts and apply them to the main context.
*
* @param context the context in which to execute this kernel
* @param innerContext1 the first inner context
* @param innerContext2 the second inner context
* @param dEdu0 the derivative of the final energy with respect to the first inner context's energy
* @param dEdu1 the derivative of the final energy with respect to the second inner context's energy
* @param energyParamDerivs derivatives of the final energy with respect to global parameters
*/
void applyForces(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1,
double dEdu0, double dEdu1, const std::map<std::string, double>& energyParamDerivs);
/**
* Copy changed parameters over to a context.
*
* @param context the context to copy parameters to
* @param force the ATMForce to copy the parameters from
*/
void copyParametersToContext(ContextImpl& context, const ATMForce& force);
/**
* Copy state information to the inner contexts.
*
* @param context the context in which to execute this kernel
* @param innerContext1 the first context created by the ATMForce for computing displaced energy
* @param innerContext2 the second context created by the ATMForce for computing displaced energy
*/
void copyState(ContextImpl& context, ContextImpl& innerContext1, ContextImpl& innerContext2);
/**
* Get the ComputeContext corresponding to the inner Context.
*/
virtual ComputeContext& getInnerComputeContext(ContextImpl& innerContext) = 0;
private:
class ForceInfo;
class ReorderListener;
void initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1);
bool hasInitializedKernel;
ComputeContext& cc;
std::vector<mm_float4> displVector1;
std::vector<mm_float4> displVector0;
ComputeArray displ1;
ComputeArray displ0;
ComputeKernel copyStateKernel;
ComputeKernel hybridForceKernel;
int numParticles;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_COMMONKERNELS_H_*/ #endif /*OPENMM_COMMONKERNELS_H_*/
...@@ -7788,3 +7788,189 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex ...@@ -7788,3 +7788,189 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex
if (atomsWereReordered || cc.getAtomsWereReordered()) if (atomsWereReordered || cc.getAtomsWereReordered())
cc.setAtomIndex(lastAtomOrder); cc.setAtomIndex(lastAtomOrder);
} }
class CommonCalcATMForceKernel::ForceInfo : public ComputeForceInfo {
public:
ForceInfo(ComputeForceInfo& force) : force(force) {
}
bool areParticlesIdentical(int particle1, int particle2) {
return force.areParticlesIdentical(particle1, particle2);
}
int getNumParticleGroups() {
return force.getNumParticleGroups();
}
void getParticlesInGroup(int index, vector<int>& particles) {
force.getParticlesInGroup(index, particles);
}
bool areGroupsIdentical(int group1, int group2) {
return force.areGroupsIdentical(group1, group2);
}
private:
ComputeForceInfo& force;
};
class CommonCalcATMForceKernel::ReorderListener : public ComputeContext::ReorderListener {
public:
ReorderListener(ComputeContext& cc, vector<mm_float4>& displVector1, ArrayInterface& displ1,
vector<mm_float4>& displVector0, ArrayInterface& displ0) :
cc(cc), displVector1(displVector1), displ1(displ1), displVector0(displVector0), displ0(displ0) {
}
void execute() {
const vector<int>& id = cc.getAtomIndex();
vector<mm_float4> newDisplVectorContext1(cc.getPaddedNumAtoms());
vector<mm_float4> newDisplVectorContext0(cc.getPaddedNumAtoms());
for (int i = 0; i < cc.getNumAtoms(); i++) {
newDisplVectorContext1[i] = displVector1[id[i]];
newDisplVectorContext0[i] = displVector0[id[i]];
}
displ1.upload(newDisplVectorContext1);
displ0.upload(newDisplVectorContext0);
}
private:
ComputeContext& cc;
ArrayInterface& displ1;
ArrayInterface& displ0;
std::vector<mm_float4> displVector1;
std::vector<mm_float4> displVector0;
};
CommonCalcATMForceKernel::~CommonCalcATMForceKernel() {
}
void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce& force) {
ContextSelector selector(cc);
numParticles = force.getNumParticles();
if (numParticles == 0)
return;
displVector1.resize(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
displVector0.resize(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVectorContext1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVectorContext0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
for (int i = 0; i < numParticles; i++) {
Vec3 displacement1, displacement0;
force.getParticleParameters(i, displacement1, displacement0);
displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0);
displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0);
}
const vector<int>& id = cc.getAtomIndex();
for (int i = 0; i < numParticles; i++)
displVectorContext1[i] = displVector1[id[i]];
displ1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ1");
displ1.upload(displVectorContext1);
for (int i = 0; i < numParticles; i++)
displVectorContext0[i] = displVector0[id[i]];
displ0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ0");
displ0.upload(displVectorContext0);
for (int i = 0; i < force.getNumEnergyParameterDerivatives(); i++)
cc.addEnergyParameterDerivative(force.getEnergyParameterDerivativeName(i));
cc.addForce(new ComputeForceInfo());
}
void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1) {
if (!hasInitializedKernel) {
hasInitializedKernel = true;
//inner contexts
ComputeContext& cc0 = getInnerComputeContext(innerContext0);
ComputeContext& cc1 = getInnerComputeContext(innerContext1);
//initialize the listener, this reorders the displacement vectors
ReorderListener* listener = new ReorderListener(cc, displVector1, displ1, displVector0, displ0);
cc.addReorderListener(listener);
listener->execute();
//create CopyState kernel
ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce);
copyStateKernel = program->createKernel("copyState");
copyStateKernel->addArg(numParticles);
copyStateKernel->addArg(cc.getPosq());
copyStateKernel->addArg(cc0.getPosq());
copyStateKernel->addArg(cc1.getPosq());
copyStateKernel->addArg(displ0);
copyStateKernel->addArg(displ1);
if (cc.getUseMixedPrecision()) {
copyStateKernel->addArg(cc.getPosqCorrection());
copyStateKernel->addArg(cc0.getPosqCorrection());
copyStateKernel->addArg(cc1.getPosqCorrection());
}
//create the HybridForce kernel
hybridForceKernel = program->createKernel("hybridForce");
hybridForceKernel->addArg(numParticles);
hybridForceKernel->addArg(cc.getPaddedNumAtoms());
hybridForceKernel->addArg(cc.getLongForceBuffer());
hybridForceKernel->addArg(cc0.getLongForceBuffer());
hybridForceKernel->addArg(cc1.getLongForceBuffer());
hybridForceKernel->addArg();
hybridForceKernel->addArg();
cc0.addForce(new ComputeForceInfo());
cc1.addForce(new ComputeForceInfo());
}
}
void CommonCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1,
double dEdu0, double dEdu1, const map<string, double>& energyParamDerivs) {
ContextSelector selector(cc);
initKernels(context, innerContext0, innerContext1);
if (cc.getUseDoublePrecision()) {
hybridForceKernel->setArg(5, dEdu0);
hybridForceKernel->setArg(6, dEdu1);
}
else {
hybridForceKernel->setArg(5, (float) dEdu0);
hybridForceKernel->setArg(6, (float) dEdu1);
}
hybridForceKernel->execute(numParticles);
map<string, double>& derivs = cc.getEnergyParamDerivWorkspace();
for (auto deriv : energyParamDerivs)
derivs[deriv.first] += deriv.second;
}
void CommonCalcATMForceKernel::copyState(ContextImpl& context,
ContextImpl& innerContext0, ContextImpl& innerContext1) {
ContextSelector selector(cc);
initKernels(context, innerContext0, innerContext1);
copyStateKernel->execute(numParticles);
Vec3 a, b, c;
context.getPeriodicBoxVectors(a, b, c);
innerContext0.setPeriodicBoxVectors(a, b, c);
innerContext0.setTime(context.getTime());
innerContext1.setPeriodicBoxVectors(a, b, c);
innerContext1.setTime(context.getTime());
map<string, double> innerParameters0 = innerContext0.getParameters();
for (auto& param : innerParameters0)
innerContext0.setParameter(param.first, context.getParameter(param.first));
map<string, double> innerParameters1 = innerContext1.getParameters();
for (auto& param : innerParameters1)
innerContext1.setParameter(param.first, context.getParameter(param.first));
}
void CommonCalcATMForceKernel::copyParametersToContext(ContextImpl& context, const ATMForce& force) {
if (force.getNumParticles() != numParticles)
throw OpenMMException("copyParametersToContext: The number of ATMMetaForce particles has changed");
displVector1.resize(cc.getPaddedNumAtoms());
displVector0.resize(cc.getPaddedNumAtoms());
for (int i = 0; i < numParticles; i++) {
Vec3 displacement1, displacement0;
force.getParticleParameters(i, displacement1, displacement0);
displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0);
displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0);
}
const vector<int>& id = cc.getAtomIndex();
vector<mm_float4> displVectorContext1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVectorContext0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
for (int i = 0; i < numParticles; i++) {
displVectorContext1[i] = displVector1[id[i]];
displVectorContext0[i] = displVector0[id[i]];
}
displ1.upload(displVectorContext1);
displ0.upload(displVectorContext0);
}
KERNEL void hybridForce(int numParticles,
int paddedNumParticles,
GLOBAL mm_long* RESTRICT force,
GLOBAL mm_long* RESTRICT force0,
GLOBAL mm_long* RESTRICT force1,
real dEdu0,
real dEdu1) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
force[i] += (mm_long) (dEdu0*force0[i] + dEdu1*force1[i]);
force[i+paddedNumParticles] += (mm_long) (dEdu0*force0[i+paddedNumParticles] + dEdu1*force1[i+paddedNumParticles]);
force[i+paddedNumParticles*2] += (mm_long) (dEdu0*force0[i+paddedNumParticles*2] + dEdu1*force1[i+paddedNumParticles*2]);
}
}
KERNEL void copyState(int numParticles,
GLOBAL real4* RESTRICT posq,
GLOBAL real4* RESTRICT posq0,
GLOBAL real4* RESTRICT posq1,
GLOBAL float4* RESTRICT displ0,
GLOBAL float4* RESTRICT displ1
#ifdef USE_MIXED_PRECISION
,
GLOBAL real4* RESTRICT posqCorrection,
GLOBAL real4* RESTRICT posq0Correction,
GLOBAL real4* RESTRICT posq1Correction
#endif
) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
real4 p0 = posq[i] + make_real4((real) displ0[i].x, (real) displ0[i].y, (real) displ0[i].z, 0);
real4 p1 = posq[i] + make_real4((real) displ1[i].x, (real) displ1[i].y, (real) displ1[i].z, 0);
p0.w = posq0[i].w;
p1.w = posq1[i].w;
posq0[i] = p0;
posq1[i] = p1;
#ifdef USE_MIXED_PRECISION
posq0Correction[i] = posqCorrection[i];
posq1Correction[i] = posqCorrection[i];
#endif
}
}
...@@ -368,6 +368,15 @@ public: ...@@ -368,6 +368,15 @@ public:
} }
}; };
class CudaCalcATMForceKernel : public CommonCalcATMForceKernel {
public:
CudaCalcATMForceKernel(std::string name, const Platform& platform, ComputeContext& cc) : CommonCalcATMForceKernel(name, platform, cc) {
}
ComputeContext& getInnerComputeContext(ContextImpl& innerContext) {
return *reinterpret_cast<CudaPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
}
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_CUDAKERNELS_H_*/ #endif /*OPENMM_CUDAKERNELS_H_*/
......
...@@ -111,6 +111,8 @@ KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform ...@@ -111,6 +111,8 @@ KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform
return new CommonCalcCustomCompoundBondForceKernel(name, platform, cu, context.getSystem()); return new CommonCalcCustomCompoundBondForceKernel(name, platform, cu, context.getSystem());
if (name == CalcCustomCVForceKernel::Name()) if (name == CalcCustomCVForceKernel::Name())
return new CudaCalcCustomCVForceKernel(name, platform, cu); return new CudaCalcCustomCVForceKernel(name, platform, cu);
if (name == CalcATMForceKernel::Name())
return new CudaCalcATMForceKernel(name, platform, cu);
if (name == CalcRMSDForceKernel::Name()) if (name == CalcRMSDForceKernel::Name())
return new CommonCalcRMSDForceKernel(name, platform, cu); return new CommonCalcRMSDForceKernel(name, platform, cu);
if (name == CalcCustomManyParticleForceKernel::Name()) if (name == CalcCustomManyParticleForceKernel::Name())
......
...@@ -92,6 +92,7 @@ CudaPlatform::CudaPlatform() { ...@@ -92,6 +92,7 @@ CudaPlatform::CudaPlatform() {
registerKernelFactory(CalcCustomCentroidBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomCentroidBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomCompoundBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomCompoundBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory); registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory); registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory); registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
registerKernelFactory(CalcGayBerneForceKernel::Name(), factory); registerKernelFactory(CalcGayBerneForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2023 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "CudaTests.h"
#include "TestATMForce.h"
void runPlatformTests() {
}
...@@ -345,6 +345,19 @@ public: ...@@ -345,6 +345,19 @@ public:
} }
}; };
/**
* This kernel is invoked by ATMForce to calculate the forces acting on the system and the energy of the system.
*/
class OpenCLCalcATMForceKernel : public CommonCalcATMForceKernel {
public:
OpenCLCalcATMForceKernel(std::string name, const Platform& platform, ComputeContext& cc) : CommonCalcATMForceKernel(name, platform, cc) {
}
ComputeContext& getInnerComputeContext(ContextImpl& innerContext) {
return *reinterpret_cast<OpenCLPlatform::PlatformData*>(innerContext.getPlatformData())->contexts[0];
}
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_OPENCLKERNELS_H_*/ #endif /*OPENMM_OPENCLKERNELS_H_*/
...@@ -109,6 +109,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo ...@@ -109,6 +109,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo
return new CommonCalcCustomCompoundBondForceKernel(name, platform, cl, context.getSystem()); return new CommonCalcCustomCompoundBondForceKernel(name, platform, cl, context.getSystem());
if (name == CalcCustomCVForceKernel::Name()) if (name == CalcCustomCVForceKernel::Name())
return new OpenCLCalcCustomCVForceKernel(name, platform, cl); return new OpenCLCalcCustomCVForceKernel(name, platform, cl);
if (name == CalcATMForceKernel::Name())
return new OpenCLCalcATMForceKernel(name, platform, cl);
if (name == CalcRMSDForceKernel::Name()) if (name == CalcRMSDForceKernel::Name())
return new CommonCalcRMSDForceKernel(name, platform, cl); return new CommonCalcRMSDForceKernel(name, platform, cl);
if (name == CalcCustomManyParticleForceKernel::Name()) if (name == CalcCustomManyParticleForceKernel::Name())
......
...@@ -83,6 +83,7 @@ OpenCLPlatform::OpenCLPlatform() { ...@@ -83,6 +83,7 @@ OpenCLPlatform::OpenCLPlatform() {
registerKernelFactory(CalcCustomCentroidBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomCentroidBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomCompoundBondForceKernel::Name(), factory); registerKernelFactory(CalcCustomCompoundBondForceKernel::Name(), factory);
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory); registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory); registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory); registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
registerKernelFactory(CalcGayBerneForceKernel::Name(), factory); registerKernelFactory(CalcGayBerneForceKernel::Name(), factory);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment