Unverified Commit 3645e1ee authored by Emilio Gallicchio's avatar Emilio Gallicchio Committed by GitHub
Browse files

Variable distance-based displacements for ATMForce (#4776)



* Variable displacements based on particle positions

* set variable displacements when adding particles

* update documentation

* address compilation error in OpenMMFortranWrapper

* update python API tests

* fix stray 'and'

* addParticle() without arguments add a particle that is not displaced

* pack displacement particles into int4

* put back default displacement removed in error

* ATMForce interface with coordinate transformation objects

* revise variable displacement API

* documentation, formatting, serialization

* Fixed C and Fortran wrappers

* Fixed Python wrappers

* Fixed factory

* Sort files to ensure classes are listed in the correct order

* Converted APIUnits test to new ATMForce API

* write class name

* skip the documentation for forward declarations

* undo 9e91d0b since it does not fix the doc build

* remove temporary doc files for nested classes

* Clean away tabs

---------
Co-authored-by: default avatarPeter Eastman <peastman@stanford.edu>
parent 4a956a72
...@@ -56,6 +56,7 @@ add_custom_command( ...@@ -56,6 +56,7 @@ add_custom_command(
add_custom_command( add_custom_command(
OUTPUT "${CMAKE_BINARY_DIR}/api-c++/index.html" OUTPUT "${CMAKE_BINARY_DIR}/api-c++/index.html"
COMMAND "${CMAKE_COMMAND}" -E rm "${CMAKE_CURRENT_BINARY_DIR}/generated/*::*"
COMMAND "${PYTHON_EXECUTABLE}" -m sphinx . "${CMAKE_BINARY_DIR}/api-c++" -W --keep-going # Promote warnings to errors to catch undocumented classes COMMAND "${PYTHON_EXECUTABLE}" -m sphinx . "${CMAKE_BINARY_DIR}/api-c++" -W --keep-going # Promote warnings to errors to catch undocumented classes
WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}" WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}"
DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/conf.py" DEPENDS "${CMAKE_CURRENT_BINARY_DIR}/conf.py"
......
...@@ -40,6 +40,7 @@ ...@@ -40,6 +40,7 @@
#include <openmm/Vec3.h> #include <openmm/Vec3.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <map>
#include "internal/windowsExport.h" #include "internal/windowsExport.h"
namespace OpenMM { namespace OpenMM {
...@@ -59,7 +60,7 @@ namespace OpenMM { ...@@ -59,7 +60,7 @@ namespace OpenMM {
* and please cite it to support our work if you use this software in your research. * and please cite it to support our work if you use this software in your research.
* *
* The ATMForce implements an arbitrary potential energy function that depends on the potential * The ATMForce implements an arbitrary potential energy function that depends on the potential
* energies (u0 and u1) of the system before and after a set of atoms are displaced by a specified amount. * energies (u0 and u1) of the system before and after a set of atoms are displaced by some amount.
* For example, you might displace a molecule from the solvent bulk to a receptor binding site to simulate * For example, you might displace a molecule from the solvent bulk to a receptor binding site to simulate
* a binding process. The potential energy function typically also depends on one or more parameters that * a binding process. The potential energy function typically also depends on one or more parameters that
* are dialed to implement alchemical transformations. * are dialed to implement alchemical transformations.
...@@ -69,9 +70,10 @@ namespace OpenMM { ...@@ -69,9 +70,10 @@ namespace OpenMM {
* of the variables u0 and u1. Then call addGlobalParameter() to define the parameters on which the potential energy expression depends. * of the variables u0 and u1. Then call addGlobalParameter() to define the parameters on which the potential energy expression depends.
* The values of global parameters may be modified during a simulation by calling Context::setParameter(). * The values of global parameters may be modified during a simulation by calling Context::setParameter().
* Next, call addForce() to add Force objects that define the terms of the potential energy function * Next, call addForce() to add Force objects that define the terms of the potential energy function
* that change upon displacement. Finally, call addParticle() to specify the displacement applied to * that change upon displacement. Finally, call addParticle() to specify the coordinate transformation applied to
* each particle. Displacements can be changed by calling setParticleParameters(). As any per-particle parameters, * each particle. Currently supported coordinate transformations consist of displacing the positions of particles by a fixed amount
* changes in displacements take effect only after calling updateParametersInContext(). * or by the offset of the positions between two given particles. As any per-particle parameters, changes in particle coordinate
* transformations take effect only after calling updateParametersInContext().
* *
* As an example, the following code creates an ATMForce based on the change in energy of * As an example, the following code creates an ATMForce based on the change in energy of
* two particles when the second particle is displaced by 1 nm in the x direction. * two particles when the second particle is displaced by 1 nm in the x direction.
...@@ -82,13 +84,34 @@ namespace OpenMM { ...@@ -82,13 +84,34 @@ namespace OpenMM {
* *
* ATMForce *atmforce = new ATMForce("u0 + Lambda*(u1 - u0)"); * ATMForce *atmforce = new ATMForce("u0 + Lambda*(u1 - u0)");
* atm->addGlobalParameter("Lambda", 0.5); * atm->addGlobalParameter("Lambda", 0.5);
* atm->addParticle(Vec3(0, 0, 0)); * atm->addParticle();
* atm->addParticle(Vec3(1, 0, 0)); * atm->addParticle(new ATMForce::FixedDisplacement(Vec3(1, 0, 0)));
* CustomBondForce* force = new CustomBondForce("0.5*r^2"); * CustomBondForce* force = new CustomBondForce("0.5*r^2");
* atm->addForce(force); * atm->addForce(force);
* \endverbatim * \endverbatim
* *
* Expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following * Note that calling addParticle() without arguments is equivalent to a zero fixed displacement.
*
* In the example above, the displacement is specified by fixed lab-frame vector. ATMForce also supports variable displacements in internal
* system coordinates in terms of vector distance between specified particles. For example, if pos[] is the internal array holding the positions
* of the particles, the following code creates an ATMForce based on the change in energy when the first particle is displaced by the vector
* pos[2]-pos[1] going from the second particle to the third particle,
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* ATMForce *atmforce = new ATMForce("u0 + Lambda*(u1 - u0)");
* atm->addGlobalParameter("Lambda", 0.5);
* atm->addParticle(new ATMForce::ParticleOffsetDisplacement(2, 1));
* atm->addParticle();
* atm->addParticle();
* CustomBondForce* force = new CustomBondForce("0.5*r^2");
* atm->addForce(force);
* \endverbatim
*
* where ParticleOffsetDisplacement is the class that describes this particular type of coordinate transformation.
*
* Energy expressions may involve the operators + (add), - (subtract), * (multiply), / (divide), and ^ (power), and the following
* functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, atan2, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta, * functions: sqrt, exp, log, sin, cos, sec, csc, tan, cot, asin, acos, atan, atan2, sinh, cosh, tanh, erf, erfc, min, max, abs, floor, ceil, step, delta,
* select. All trigonometric functions * select. All trigonometric functions
* are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise. * are defined in radians, and log is the natural logarithm. step(x) = 0 if x is less than 0, 1 otherwise. delta(x) = 1 if x is 0, 0 otherwise.
...@@ -96,7 +119,7 @@ namespace OpenMM { ...@@ -96,7 +119,7 @@ namespace OpenMM {
* *
* If instead of the energy expression the ATMForce constructor specifies the values of a series of parameters, * If instead of the energy expression the ATMForce constructor specifies the values of a series of parameters,
* the default energy expression is used: * the default energy expression is used:
* *
* \verbatim embed:rst:leading-asterisk * \verbatim embed:rst:leading-asterisk
* .. code-block:: * .. code-block::
* *
...@@ -107,7 +130,7 @@ namespace OpenMM { ...@@ -107,7 +130,7 @@ namespace OpenMM {
* y = (u-Ubcore)/(Umax-Ubcore); * y = (u-Ubcore)/(Umax-Ubcore);
* u = select(step(Direction), 1, -1)*(u1-u0) * u = select(step(Direction), 1, -1)*(u1-u0)
* \endverbatim * \endverbatim
* *
* which is the same as the soft-core softplus alchemical potential energy function in the Azimi et al. paper above. * which is the same as the soft-core softplus alchemical potential energy function in the Azimi et al. paper above.
* *
* The ATMForce is then added to the System as any other Force * The ATMForce is then added to the System as any other Force
...@@ -125,18 +148,26 @@ namespace OpenMM { ...@@ -125,18 +148,26 @@ namespace OpenMM {
* In most cases, particles are only displaced in one of the two states evaluated by this force. It computes the * In most cases, particles are only displaced in one of the two states evaluated by this force. It computes the
* change in energy between the current particle coordinates (as stored in the Context) and the displaced coordinates. * change in energy between the current particle coordinates (as stored in the Context) and the displaced coordinates.
* In some cases, it is useful to apply displacements to both states. You can do this by providing two displacement * In some cases, it is useful to apply displacements to both states. You can do this by providing two displacement
* vectors to addParticle(): * vectors to the fixed displacement transformation given to addParticle(). For example, with:
* *
* \verbatim embed:rst:leading-asterisk * \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp * .. code-block:: cpp
* *
* atm->addParticle(Vec3(1, 0, 0), Vec3(-1, 0, 0)); * atm->addParticle(new ATMForce::FixedDisplacement(Vec3(1, 0, 0), Vec3(-1, 0, 0)));
* \endverbatim * \endverbatim
* *
* In this case, u1 will be computed after displacing the particle in the positive x direction, and * the energy u1 will be computed after displacing the particle in the positive x direction, and
* u0 will be computed after displacing it in the negative x direction. * u0 will be computed after displacing it in the negative x direction. Similarly,
* *
* This class also has the ability to compute derivatives of the potential energy with respect to global parameters. * \verbatim embed:rst:leading-asterisk
* .. code-block:: cpp
*
* atm->addParticle(new ATMForce::ParticleOffsetDisplacement(4, 3, 2, 1));
* \endverbatim
*
* adds a particle whose position is displaced by pos[4]-pos[3] before calculating u1 and by pos[2]-pos[1] before calculating u0.
*
* The ATMForce class has the ability to compute derivatives of the potential energy with respect to global parameters.
* Call addEnergyParameterDerivative() to request that the derivative with respect to a particular parameter be * Call addEnergyParameterDerivative() to request that the derivative with respect to a particular parameter be
* computed. You can then query its value in a Context by calling getState() on it. * computed. You can then query its value in a Context by calling getState() on it.
*/ */
...@@ -144,7 +175,7 @@ namespace OpenMM { ...@@ -144,7 +175,7 @@ namespace OpenMM {
class OPENMM_EXPORT ATMForce : public OpenMM::Force { class OPENMM_EXPORT ATMForce : public OpenMM::Force {
public: public:
/** /**
* Create an ATMForce object. * Create an ATMForce object.
* *
* @param energy an algebraic expression giving the energy of the system as a function * @param energy an algebraic expression giving the energy of the system as a function
* of u0 and u1, the energies before and after displacement * of u0 and u1, the energies before and after displacement
...@@ -154,7 +185,7 @@ public: ...@@ -154,7 +185,7 @@ public:
* Create an ATMForce object with the default softplus energy expression. The values passed to * Create an ATMForce object with the default softplus energy expression. The values passed to
* this constructor are the default values of the global parameters for newly created Contexts. * this constructor are the default values of the global parameters for newly created Contexts.
* Their values can be changed by calling setParameter() on the Context using the parameter * Their values can be changed by calling setParameter() on the Context using the parameter
* names defined by the Lambda1(), Lambda2(), etc. methods below. * names defined by the Lambda1(), Lambda2(), etc. methods below.
* *
* @param lambda1 the default value of the Lambda1 parameter (dimensionless). This should be * @param lambda1 the default value of the Lambda1 parameter (dimensionless). This should be
* a number between 0 and 1. * a number between 0 and 1.
...@@ -212,8 +243,22 @@ public: ...@@ -212,8 +243,22 @@ public:
* return the force from index * return the force from index
*/ */
Force& getForce(int index) const; Force& getForce(int index) const;
/** /**
* Add a particle to the force. * Add a stationary particle: one whose coordinate is not transformed
*
* All of the particles in the System must be added to the ATMForce in the same order
* as they appear in the System.
*
* @return the index of the particle that was added
*/
int addParticle();
/**
* Add a particle to the force with fixed lab frame displacements
*
* @deprecated This method exists only for backward compatibility. Use:
* addParticle(new ATMFixedDisplacement(displacement1, displacement0))
* *
* All of the particles in the System must be added to the ATMForce in the same order * All of the particles in the System must be added to the ATMForce in the same order
* as they appear in the System. * as they appear in the System.
...@@ -223,22 +268,52 @@ public: ...@@ -223,22 +268,52 @@ public:
* @return the index of the particle that was added * @return the index of the particle that was added
*/ */
int addParticle(const Vec3& displacement1, const Vec3& displacement0=Vec3()); int addParticle(const Vec3& displacement1, const Vec3& displacement0=Vec3());
class CoordinateTransformation;
class FixedDisplacement;
class ParticleOffsetDisplacement;
/**
* Add a particle to the force with a coordinate transformation method
*
* All of the particles in the System must be added to the ATMForce in the same order
* as they appear in the System.
*
* @param transformation the pointer to the CoordinateTransformation object, which should have been
* created on the heap with the "new" operator. The ATMForce takes over
* ownership of it, and deletes the CoordinateTransformation when the ATMForce
* itself is deleted. Currently supported transformations are FixedDisplacement and
* ParticleOffsetDisplacement.
* @return the index of the particle that was added
*/
int addParticle(CoordinateTransformation* transformation);
/** /**
* Get the parameters for a particle * Get the parameters for a particle
* *
* @deprecated This method exists only for backward compatibility. Use:
* const ATMForce::CoordinateTransformation& transformation = getParticleTransformation(index);
* Vec3 displacement1 = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation)->getFixedDisplacement1();
* Vec3 displacement0 = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation)->getFixedDisplacement0();
*
* @param index the index in the force for the particle for which to get parameters * @param index the index in the force for the particle for which to get parameters
* @param displacement1 the displacement of the particle for the target state in nm * @param displacement1 the fixed lab-frame displacement of the particle for the target state in nm
* @param displacement0 the displacement of the particle for the initial state in nm * @param displacement0 the fixed lab-frame displacement of the particle for the initial state in nm
*/ */
void getParticleParameters(int index, Vec3& displacement1, Vec3& displacement0) const; void getParticleParameters(int index, Vec3& displacement1, Vec3& displacement0) const;
/** /**
* Set the parameters for a particle * Set the displacements for a particle as fixed lab frame vectors
* *
* @deprecated This method exists only for backward compatibility. Use:
* setParticleTransformation(index, new ATMForce::FixedDisplacement(displacement1, displacement0))
*
* @param index the index in the force of the particle for which to set parameters * @param index the index in the force of the particle for which to set parameters
* @param displacement1 the displacement of the particle for the target state in nm * @param displacement1 the fixed lab-frame displacement of the particle for the target state in nm
* @param displacement0 the displacement of the particle for the initial state in nm * @param displacement0 the fixed lab-frame displacement of the particle for the initial state in nm
*/ */
void setParticleParameters(int index, const Vec3& displacement1, const Vec3& displacement0=Vec3()); void setParticleParameters(int index, const Vec3& displacement1, const Vec3& displacement0=Vec3());
/** /**
* Add a new global parameter that the interaction may depend on. The default value provided to * Add a new global parameter that the interaction may depend on. The default value provided to
* this method is the initial value of the parameter in newly created Contexts. You can change * this method is the initial value of the parameter in newly created Contexts. You can change
...@@ -305,7 +380,7 @@ public: ...@@ -305,7 +380,7 @@ public:
bool usesPeriodicBoundaryConditions() const; bool usesPeriodicBoundaryConditions() const;
/** /**
* Returns the current perturbation energy. * Returns the current perturbation energy.
* *
* @param context the Context for which to return the energy * @param context the Context for which to return the energy
* @param u1 on exit, the energy of the displaced state * @param u1 on exit, the energy of the displaced state
* @param u0 on exit, the energy of the non-displaced state * @param u0 on exit, the energy of the non-displaced state
...@@ -392,6 +467,25 @@ public: ...@@ -392,6 +467,25 @@ public:
return key; return key;
} }
/**
* Change the coordinate transformation method for the specified particle
*
* @param index the index of the particle
* @param transformation the pointer to the CoordinateTransformation object, which should have been
* created on the heap with the "new" operator. The ATMForce takes over
* ownership of it, and deletes the CoordinateTransformation when the ATMForce
* itself is deleted.
*/
void setParticleTransformation(int index, CoordinateTransformation* transformation);
/**
* Returns the Transformation object associated with the particle
*
* @param index the index of the particle
* @return the CoordinateTransformation object associated with the particle
*/
const CoordinateTransformation& getParticleTransformation(int index) const;
protected: protected:
ForceImpl* createImpl() const; ForceImpl* createImpl() const;
private: private:
...@@ -411,13 +505,12 @@ private: ...@@ -411,13 +505,12 @@ private:
class ATMForce::ParticleInfo { class ATMForce::ParticleInfo {
public: public:
int index; int index;
Vec3 displacement1, displacement0; CoordinateTransformation* transformation;
ParticleInfo() : index(-1) { ParticleInfo() : index(-1), transformation(NULL) {
} }
ParticleInfo(int index) : index(index) { ParticleInfo(int index) : index(index), transformation(NULL) {
} }
ParticleInfo(int index, Vec3 displacement1, Vec3 displacement0) : ParticleInfo(int index, CoordinateTransformation* transformation) : index(index), transformation(transformation) {
index(index), displacement1(displacement1), displacement0(displacement0) {
} }
}; };
...@@ -435,6 +528,68 @@ public: ...@@ -435,6 +528,68 @@ public:
} }
}; };
/**
* The CoordinateTransformation class describes a generic coordinate transformation applied
* to a particle. It is a virtual base class. Use the derived classes FixedDisplacement and
* ParticleOffsetDisplacement to define actual coordinate transformations.
*/
class ATMForce::CoordinateTransformation {
protected:
CoordinateTransformation(){}
public:
virtual ~CoordinateTransformation(){};
};
/**
* The FixedDisplacement class describes a coordinate transformation where a particle is displaced by
* a fixed amount. To use it, create a FixedDisplacement object passing the displacement vectors for the two
* states evaluated by ATMForce. The first displacement applies to the target state, and the second
* to the reference state. The second displacement can be omitted, in which case it is set to zero.
*/
class ATMForce::FixedDisplacement : public ATMForce::CoordinateTransformation {
public:
FixedDisplacement(const Vec3& displacement1, const Vec3& displacement0=Vec3()) : displ1(displacement1), displ0(displacement0) {
}
~FixedDisplacement() override {}
const Vec3& getFixedDisplacement1() const {
return displ1;
}
const Vec3& getFixedDisplacement0() const {
return displ0;
}
private:
Vec3 displ1, displ0;
};
/**
* The ParticleOffsetDisplacement class describes a coordinate transformation in which a particle is displaced by the
* vector distance between two particles. The displacement is variable because it changes as the two particles move.
* To use it, create a ParticleOffsetDisplacement passing the indexes, pDestination1 and pOrigin1, respectively, of
* the two particles, resulting in the variable displacement pos[pDestination1]-pos[pOrigin1] if the array pos holds the
* particles' positions. Optionally, a second set of particles, pDestination0 and pOrigin0, can be specified to apply
* a similar variable displacement at the reference state of the ATMForce.
*/
class ATMForce::ParticleOffsetDisplacement : public ATMForce::CoordinateTransformation {
public:
ParticleOffsetDisplacement(int pDestination1, int pOrigin1, int pDestination0 = -1, int pOrigin0 = -1) : pDestination1(pDestination1), pOrigin1(pOrigin1), pDestination0(pDestination0), pOrigin0(pOrigin0) {
}
~ParticleOffsetDisplacement() override {}
int getDestinationParticle1() const {
return pDestination1;
}
int getOriginParticle1() const {
return pOrigin1;
}
int getDestinationParticle0() const {
return pDestination0;
}
int getOriginParticle0() const {
return pOrigin0;
}
private:
int pDestination1, pOrigin1, pDestination0, pOrigin0;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_ATMFORCE_H_*/ #endif /*OPENMM_ATMFORCE_H_*/
...@@ -38,9 +38,11 @@ ...@@ -38,9 +38,11 @@
#include "openmm/internal/ATMForceImpl.h" #include "openmm/internal/ATMForceImpl.h"
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include <openmm/Vec3.h>
#include <iostream> #include <iostream>
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <map>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -65,7 +67,7 @@ ATMForce::ATMForce(double lambda1, double lambda2, double alpha, double uh, doub ...@@ -65,7 +67,7 @@ ATMForce::ATMForce(double lambda1, double lambda2, double alpha, double uh, doub
"fsc = (z^Acore-1)/(z^Acore+1);" "fsc = (z^Acore-1)/(z^Acore+1);"
"z = 1 + 2*(y/Acore) + 2*(y/Acore)^2;" "z = 1 + 2*(y/Acore) + 2*(y/Acore)^2;"
"y = (u-Ubcore)/(Umax-Ubcore);" "y = (u-Ubcore)/(Umax-Ubcore);"
"u = select(step(Direction), 1, -1)*(u1-u0)"; "u = select(step(Direction), 1, -1)*(u1-u0)";
setEnergyFunction(referencePotExpression + alchemicalPotExpression + softCoreExpression); setEnergyFunction(referencePotExpression + alchemicalPotExpression + softCoreExpression);
addGlobalParameter(Lambda1(), lambda1); addGlobalParameter(Lambda1(), lambda1);
addGlobalParameter(Lambda2(), lambda2); addGlobalParameter(Lambda2(), lambda2);
...@@ -81,6 +83,10 @@ ATMForce::ATMForce(double lambda1, double lambda2, double alpha, double uh, doub ...@@ -81,6 +83,10 @@ ATMForce::ATMForce(double lambda1, double lambda2, double alpha, double uh, doub
ATMForce::~ATMForce() { ATMForce::~ATMForce() {
for (Force* force : forces) for (Force* force : forces)
delete force; delete force;
for (ParticleInfo particle : particles) {
if (particle.transformation)
delete particle.transformation;
}
} }
const string& ATMForce::getEnergyFunction() const { const string& ATMForce::getEnergyFunction() const {
...@@ -91,21 +97,49 @@ void ATMForce::setEnergyFunction(const std::string& energy) { ...@@ -91,21 +97,49 @@ void ATMForce::setEnergyFunction(const std::string& energy) {
energyExpression = energy; energyExpression = energy;
} }
int ATMForce::addParticle() {
particles.push_back(ParticleInfo(particles.size(), new ATMForce::FixedDisplacement(Vec3(0, 0, 0), Vec3(0, 0, 0))));
return particles.size()-1;
}
int ATMForce::addParticle(const Vec3& displacement1, const Vec3& displacement0) { int ATMForce::addParticle(const Vec3& displacement1, const Vec3& displacement0) {
particles.push_back(ParticleInfo(particles.size(), displacement1, displacement0)); FixedDisplacement* fd = new FixedDisplacement(displacement1, displacement0);
particles.push_back(ParticleInfo(particles.size(), fd));
return particles.size()-1;
}
int ATMForce::addParticle(ATMForce::CoordinateTransformation* transformation) {
particles.push_back(ParticleInfo(particles.size(), transformation));
return particles.size()-1; return particles.size()-1;
} }
void ATMForce::getParticleParameters(int index, Vec3& displacement1, Vec3& displacement0) const { void ATMForce::getParticleParameters(int index, Vec3& displacement1, Vec3& displacement0) const {
ASSERT_VALID_INDEX(index, particles); ASSERT_VALID_INDEX(index, particles);
displacement1 = particles[index].displacement1; CoordinateTransformation* transformation = particles[index].transformation;
displacement0 = particles[index].displacement0; const FixedDisplacement* displacement = dynamic_cast<const FixedDisplacement*>(transformation);
if (displacement == nullptr)
throw OpenMMException("getParticleParameters: the transformation for this particle is not a FixedDisplacement");
displacement1 = displacement->getFixedDisplacement1();
displacement0 = displacement->getFixedDisplacement0();
}
const ATMForce::CoordinateTransformation& ATMForce::getParticleTransformation(int index) const {
ASSERT_VALID_INDEX(index, particles);
return *(particles[index].transformation);
} }
void ATMForce::setParticleParameters(int index, const Vec3& displacement1, const Vec3& displacement0) { void ATMForce::setParticleParameters(int index, const Vec3& displacement1, const Vec3& displacement0) {
ASSERT_VALID_INDEX(index, particles); ASSERT_VALID_INDEX(index, particles);
particles[index].displacement1 = displacement1; if (particles[index].transformation)
particles[index].displacement0 = displacement0; delete particles[index].transformation;
FixedDisplacement* fd = new FixedDisplacement(displacement1, displacement0);
particles[index].transformation = fd;
}
void ATMForce::setParticleTransformation(int index, ATMForce::CoordinateTransformation* transformation) {
ASSERT_VALID_INDEX(index, particles);
if (particles[index].transformation)
delete particles[index].transformation;
particles[index].transformation = transformation;
} }
int ATMForce::addForce(Force* force) { int ATMForce::addForce(Force* force) {
......
...@@ -1299,18 +1299,25 @@ public: ...@@ -1299,18 +1299,25 @@ public:
* Get the ComputeContext corresponding to the inner Context. * Get the ComputeContext corresponding to the inner Context.
*/ */
virtual ComputeContext& getInnerComputeContext(ContextImpl& innerContext) = 0; virtual ComputeContext& getInnerComputeContext(ContextImpl& innerContext) = 0;
private: private:
class ReorderListener; class ReorderListener;
void initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1); void initKernels(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1);
void loadParams(int numParticles, const ATMForce& force, std::vector<Vec3>& d1, std::vector<Vec3>& d0, std::vector<int>& j1, std::vector<int>& i1, std::vector<int>& j0, std::vector<int>& i0);
bool hasInitializedKernel; bool hasInitializedKernel;
ComputeContext& cc; ComputeContext& cc;
ComputeArray displ1; ComputeArray displ1, displ0; // actual displacements used in calculation
ComputeArray displ0; ComputeArray displacement1, displacement0; // fixed lab-frame displacements
ComputeArray displParticles; // variable displacements based on atom positions
// int4 arranged as (pDestination1, pOrigin1, pDestination0, pOrigin0
ComputeArray invAtomOrder, inner0InvAtomOrder, inner1InvAtomOrder; ComputeArray invAtomOrder, inner0InvAtomOrder, inner1InvAtomOrder;
ComputeArray dforce0, dforce1; // forces due to variable displacements
ComputeKernel copyStateKernel; ComputeKernel copyStateKernel;
ComputeKernel setDisplacementsKernel;
ComputeKernel resetDisplForceKernel;
ComputeKernel displForceKernel;
ComputeKernel hybridForceKernel; ComputeKernel hybridForceKernel;
int numParticles; int numParticles;
......
...@@ -3783,23 +3783,82 @@ private: ...@@ -3783,23 +3783,82 @@ private:
CommonCalcATMForceKernel::~CommonCalcATMForceKernel() { CommonCalcATMForceKernel::~CommonCalcATMForceKernel() {
} }
void CommonCalcATMForceKernel::loadParams(int numParticles, const ATMForce& force, vector<Vec3>& d1, vector<Vec3>& d0, vector<int>& j1, vector<int>& i1, vector<int>& j0, vector<int>& i0) {
for (int p = 0; p < numParticles; p++) {
const ATMForce::CoordinateTransformation& transformation = force.getParticleTransformation(p);
if (dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation) != NULL) {
const ATMForce::FixedDisplacement* fd = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation);
d1[p] = fd->getFixedDisplacement1();
d0[p] = fd->getFixedDisplacement0();
j1[p] = i1[p] = j0[p] = i0[p] = -1;
}
else if (dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation) != NULL) {
const ATMForce::ParticleOffsetDisplacement* vd = dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation);
d1[p] = Vec3(0, 0, 0);
d0[p] = Vec3(0, 0, 0);
j1[p] = vd->getDestinationParticle1();
i1[p] = vd->getOriginParticle1();
j0[p] = vd->getDestinationParticle0();
i0[p] = vd->getOriginParticle0();
}
else {
throw OpenMMException("loadParams(): invalid particle Transformation");
}
}
}
void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce& force) { void CommonCalcATMForceKernel::initialize(const System& system, const ATMForce& force) {
ContextSelector selector(cc); ContextSelector selector(cc);
numParticles = force.getNumParticles(); numParticles = force.getNumParticles();
if (numParticles == 0) if (numParticles == 0)
return; return;
vector<mm_float4> displVector1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVector0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0)); vector<int> j1(numParticles);
for (int i = 0; i < numParticles; i++) { vector<int> i1(numParticles);
Vec3 displacement1, displacement0; vector<int> j0(numParticles);
force.getParticleParameters(i, displacement1, displacement0); vector<int> i0(numParticles);
displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0); vector<Vec3> d1(numParticles);
displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0); vector<Vec3> d0(numParticles);
} loadParams(numParticles, force, d1, d0, j1, i1, j0, i0);
displ1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ1");
displ1.upload(displVector1); vector<mm_int4> displParticlesVector(cc.getPaddedNumAtoms(), mm_int4(-1, -1, -1, -1));
displ0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ0"); if (cc.getUseDoublePrecision()) {
displ0.upload(displVector0); vector<mm_double4> displVector1(cc.getPaddedNumAtoms(), mm_double4(0, 0, 0, 0));
vector<mm_double4> displVector0(cc.getPaddedNumAtoms(), mm_double4(0, 0, 0, 0));
for (int p = 0; p < numParticles; p++) {
displVector1[p] = mm_double4(d1[p][0], d1[p][1], d1[p][2], 0);
displVector0[p] = mm_double4(d0[p][0], d0[p][1], d0[p][2], 0);
displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]);
}
displ1.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displ1");
displacement1.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displacement1");
displacement1.upload(displVector1);
displ0.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displ0");
displacement0.initialize<mm_double4>(cc, cc.getPaddedNumAtoms(), "displacement0");
displacement0.upload(displVector0);
}
else {
vector<mm_float4> displVector1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVector0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
for (int p = 0; p < numParticles; p++) {
displVector1[p] = mm_float4(d1[p][0], d1[p][1], d1[p][2], 0);
displVector0[p] = mm_float4(d0[p][0], d0[p][1], d0[p][2], 0);
displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]);
}
displ1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ1");
displacement1.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displacement1");
displacement1.upload(displVector1);
displ0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displ0");
displacement0.initialize<mm_float4>(cc, cc.getPaddedNumAtoms(), "displacement0");
displacement0.upload(displVector0);
}
displParticles.initialize<mm_int4>(cc, cc.getPaddedNumAtoms(), "displParticles");
displParticles.upload(displParticlesVector);
dforce0.initialize(cc, cc.getLongForceBuffer().getSize(), cc.getLongForceBuffer().getElementSize(), "dforce0");
dforce1.initialize(cc, cc.getLongForceBuffer().getSize(), cc.getLongForceBuffer().getElementSize(), "dforce1");
invAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "invAtomOrder"); invAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "invAtomOrder");
inner0InvAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "inner0InvAtomOrder"); inner0InvAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "inner0InvAtomOrder");
inner1InvAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "inner1InvAtomOrder"); inner1InvAtomOrder.initialize<int>(cc, cc.getPaddedNumAtoms(), "inner1InvAtomOrder");
...@@ -3832,8 +3891,21 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -3832,8 +3891,21 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
listener0->execute(); listener0->execute();
listener1->execute(); listener1->execute();
//create CopyState kernel
ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce); ComputeProgram program = cc.compileProgram(CommonKernelSources::atmforce);
//create the setDisplacements kernel
setDisplacementsKernel = program->createKernel("setDisplacements");
setDisplacementsKernel->addArg(numParticles);
setDisplacementsKernel->addArg(cc.getPosq());
setDisplacementsKernel->addArg(displacement0);
setDisplacementsKernel->addArg(displacement1);
setDisplacementsKernel->addArg(displParticles);
setDisplacementsKernel->addArg(cc.getAtomIndexArray());
setDisplacementsKernel->addArg(invAtomOrder);
setDisplacementsKernel->addArg(displ0);
setDisplacementsKernel->addArg(displ1);
//create CopyState kernel
copyStateKernel = program->createKernel("copyState"); copyStateKernel = program->createKernel("copyState");
copyStateKernel->addArg(numParticles); copyStateKernel->addArg(numParticles);
copyStateKernel->addArg(cc.getPosq()); copyStateKernel->addArg(cc.getPosq());
...@@ -3850,6 +3922,27 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -3850,6 +3922,27 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
copyStateKernel->addArg(cc1.getPosqCorrection()); copyStateKernel->addArg(cc1.getPosqCorrection());
} }
//create the resetDisplForce kernel
resetDisplForceKernel = program->createKernel("resetDisplForce");
resetDisplForceKernel->addArg(numParticles);
resetDisplForceKernel->addArg(cc.getPaddedNumAtoms());
resetDisplForceKernel->addArg(dforce0);
resetDisplForceKernel->addArg(dforce1);
//create the displForce kernel
displForceKernel = program->createKernel("displForce");
displForceKernel->addArg(numParticles);
displForceKernel->addArg(cc.getPaddedNumAtoms());
displForceKernel->addArg(cc0.getLongForceBuffer());
displForceKernel->addArg(cc1.getLongForceBuffer());
displForceKernel->addArg(dforce0);
displForceKernel->addArg(dforce1);
displForceKernel->addArg(displParticles);
displForceKernel->addArg(cc.getAtomIndexArray());
displForceKernel->addArg(invAtomOrder);
displForceKernel->addArg(inner0InvAtomOrder);
displForceKernel->addArg(inner1InvAtomOrder);
//create the HybridForce kernel //create the HybridForce kernel
hybridForceKernel = program->createKernel("hybridForce"); hybridForceKernel = program->createKernel("hybridForce");
hybridForceKernel->addArg(numParticles); hybridForceKernel->addArg(numParticles);
...@@ -3857,6 +3950,8 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in ...@@ -3857,6 +3950,8 @@ void CommonCalcATMForceKernel::initKernels(ContextImpl& context, ContextImpl& in
hybridForceKernel->addArg(cc.getLongForceBuffer()); hybridForceKernel->addArg(cc.getLongForceBuffer());
hybridForceKernel->addArg(cc0.getLongForceBuffer()); hybridForceKernel->addArg(cc0.getLongForceBuffer());
hybridForceKernel->addArg(cc1.getLongForceBuffer()); hybridForceKernel->addArg(cc1.getLongForceBuffer());
hybridForceKernel->addArg(dforce0);
hybridForceKernel->addArg(dforce1);
hybridForceKernel->addArg(invAtomOrder); hybridForceKernel->addArg(invAtomOrder);
hybridForceKernel->addArg(inner0InvAtomOrder); hybridForceKernel->addArg(inner0InvAtomOrder);
hybridForceKernel->addArg(inner1InvAtomOrder); hybridForceKernel->addArg(inner1InvAtomOrder);
...@@ -3873,13 +3968,15 @@ void CommonCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& in ...@@ -3873,13 +3968,15 @@ void CommonCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& in
double dEdu0, double dEdu1, const map<string, double>& energyParamDerivs) { double dEdu0, double dEdu1, const map<string, double>& energyParamDerivs) {
ContextSelector selector(cc); ContextSelector selector(cc);
initKernels(context, innerContext0, innerContext1); initKernels(context, innerContext0, innerContext1);
resetDisplForceKernel->execute(numParticles);
displForceKernel->execute(numParticles);
if (cc.getUseDoublePrecision()) { if (cc.getUseDoublePrecision()) {
hybridForceKernel->setArg(8, dEdu0); hybridForceKernel->setArg(10, dEdu0);
hybridForceKernel->setArg(9, dEdu1); hybridForceKernel->setArg(11, dEdu1);
} }
else { else {
hybridForceKernel->setArg(8, (float) dEdu0); hybridForceKernel->setArg(10, (float) dEdu0);
hybridForceKernel->setArg(9, (float) dEdu1); hybridForceKernel->setArg(11, (float) dEdu1);
} }
hybridForceKernel->execute(numParticles); hybridForceKernel->execute(numParticles);
map<string, double>& derivs = cc.getEnergyParamDerivWorkspace(); map<string, double>& derivs = cc.getEnergyParamDerivWorkspace();
...@@ -3905,6 +4002,8 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context, ...@@ -3905,6 +4002,8 @@ void CommonCalcATMForceKernel::copyState(ContextImpl& context,
cc0.reorderAtoms(); cc0.reorderAtoms();
cc1.reorderAtoms(); cc1.reorderAtoms();
setDisplacementsKernel->execute(numParticles);
copyStateKernel->execute(numParticles); copyStateKernel->execute(numParticles);
map<string, double> innerParameters0 = innerContext0.getParameters(); map<string, double> innerParameters0 = innerContext0.getParameters();
...@@ -3919,16 +4018,39 @@ void CommonCalcATMForceKernel::copyParametersToContext(ContextImpl& context, con ...@@ -3919,16 +4018,39 @@ void CommonCalcATMForceKernel::copyParametersToContext(ContextImpl& context, con
ContextSelector selector(cc); ContextSelector selector(cc);
if (force.getNumParticles() != numParticles) if (force.getNumParticles() != numParticles)
throw OpenMMException("copyParametersToContext: The number of ATMMetaForce particles has changed"); throw OpenMMException("copyParametersToContext: The number of ATMMetaForce particles has changed");
vector<mm_float4> displVector1(cc.getPaddedNumAtoms());
vector<mm_float4> displVector0(cc.getPaddedNumAtoms()); vector<int> j1(numParticles);
for (int i = 0; i < numParticles; i++) { vector<int> i1(numParticles);
Vec3 displacement1, displacement0; vector<int> j0(numParticles);
force.getParticleParameters(i, displacement1, displacement0); vector<int> i0(numParticles);
displVector1[i] = mm_float4(displacement1[0], displacement1[1], displacement1[2], 0); vector<Vec3> d1(numParticles);
displVector0[i] = mm_float4(displacement0[0], displacement0[1], displacement0[2], 0); vector<Vec3> d0(numParticles);
loadParams(numParticles, force, d1, d0, j1, i1, j0, i0);
vector<mm_int4> displParticlesVector(cc.getPaddedNumAtoms(), mm_int4(-1, -1, -1, -1));
if (cc.getUseDoublePrecision()) {
vector<mm_double4> displVector1(cc.getPaddedNumAtoms(), mm_double4(0, 0, 0, 0));
vector<mm_double4> displVector0(cc.getPaddedNumAtoms(), mm_double4(0, 0, 0, 0));
for (int p = 0; p < numParticles; p++) {
displVector1[p] = mm_double4(d1[p][0], d1[p][1], d1[p][2], 0);
displVector0[p] = mm_double4(d0[p][0], d0[p][1], d0[p][2], 0);
displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]);
}
displacement1.upload(displVector1);
displacement0.upload(displVector0);
}
else {
vector<mm_float4> displVector1(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
vector<mm_float4> displVector0(cc.getPaddedNumAtoms(), mm_float4(0, 0, 0, 0));
for (int p = 0; p < numParticles; p++) {
displVector1[p] = mm_float4(d1[p][0], d1[p][1], d1[p][2], 0);
displVector0[p] = mm_float4(d0[p][0], d0[p][1], d0[p][2], 0);
displParticlesVector[p] = mm_int4(j1[p], i1[p], j0[p], i0[p]);
}
displacement1.upload(displVector1);
displacement0.upload(displVector0);
} }
displ1.upload(displVector1); displParticles.upload(displParticlesVector);
displ0.upload(displVector0);
} }
class CommonCalcCustomCPPForceKernel::StartCalculationPreComputation : public ComputeContext::ForcePreComputation { class CommonCalcCustomCPPForceKernel::StartCalculationPreComputation : public ComputeContext::ForcePreComputation {
......
...@@ -3,6 +3,8 @@ KERNEL void hybridForce(int numParticles, ...@@ -3,6 +3,8 @@ KERNEL void hybridForce(int numParticles,
GLOBAL mm_long* RESTRICT force, GLOBAL mm_long* RESTRICT force,
GLOBAL mm_long* RESTRICT force0, GLOBAL mm_long* RESTRICT force0,
GLOBAL mm_long* RESTRICT force1, GLOBAL mm_long* RESTRICT force1,
GLOBAL mm_long* RESTRICT dforce0,
GLOBAL mm_long* RESTRICT dforce1,
GLOBAL int* RESTRICT invAtomOrder, GLOBAL int* RESTRICT invAtomOrder,
GLOBAL int* RESTRICT inner0InvAtomOrder, GLOBAL int* RESTRICT inner0InvAtomOrder,
GLOBAL int* RESTRICT inner1InvAtomOrder, GLOBAL int* RESTRICT inner1InvAtomOrder,
...@@ -12,18 +14,128 @@ KERNEL void hybridForce(int numParticles, ...@@ -12,18 +14,128 @@ KERNEL void hybridForce(int numParticles,
int index = invAtomOrder[i]; int index = invAtomOrder[i];
int index0 = inner0InvAtomOrder[i]; int index0 = inner0InvAtomOrder[i];
int index1 = inner1InvAtomOrder[i]; int index1 = inner1InvAtomOrder[i];
force[index] += (mm_long) (dEdu0*force0[index0] + dEdu1*force1[index1]); mm_long fx0 = force0[index0]+dforce0[index];
force[index+paddedNumParticles] += (mm_long) (dEdu0*force0[index0+paddedNumParticles] + dEdu1*force1[index1+paddedNumParticles]); mm_long fy0 = force0[index0+paddedNumParticles]+dforce0[index+paddedNumParticles];
force[index+paddedNumParticles*2] += (mm_long) (dEdu0*force0[index0+paddedNumParticles*2] + dEdu1*force1[index1+paddedNumParticles*2]); mm_long fz0 = force0[index0+paddedNumParticles*2]+dforce0[index+paddedNumParticles*2];
mm_long fx1 = force1[index1]+dforce1[index];
mm_long fy1 = force1[index1+paddedNumParticles]+dforce1[index+paddedNumParticles];
mm_long fz1 = force1[index1+paddedNumParticles*2]+dforce1[index+paddedNumParticles*2];
force[index] += (mm_long) (dEdu0*fx0 + dEdu1*fx1);
force[index+paddedNumParticles] += (mm_long) (dEdu0*fy0 + dEdu1*fy1);
force[index+paddedNumParticles*2] += (mm_long) (dEdu0*fz0 + dEdu1*fz1);
} }
} }
KERNEL void setDisplacements(int numParticles,
GLOBAL real4* RESTRICT posq,
GLOBAL real4* RESTRICT displacement0,
GLOBAL real4* RESTRICT displacement1,
GLOBAL int4* displParticles,
GLOBAL int* RESTRICT atomOrder,
GLOBAL int* RESTRICT invAtomOrder,
GLOBAL real4* RESTRICT displ0,
GLOBAL real4* RESTRICT displ1) {
for (int index = GLOBAL_ID; index < numParticles; index += GLOBAL_SIZE) {
int atom = atomOrder[index];
int pj1 = displParticles[atom].x;
int pi1 = displParticles[atom].y;
int pj0 = displParticles[atom].z;
int pi0 = displParticles[atom].w;
if (pj1 >= 0 && pi1 >= 0) {
// variable system coordinate displacements
int indexj1 = invAtomOrder[pj1];
int indexi1 = invAtomOrder[pi1];
displ1[atom] = make_real4((real) posq[indexj1].x- posq[indexi1].x,
(real) posq[indexj1].y- posq[indexi1].y,
(real) posq[indexj1].z- posq[indexi1].z, (real) 0);
if (pj0 >= 0 && pi0 >= 0) {
int indexj0 = invAtomOrder[pj0];
int indexi0 = invAtomOrder[pi0];
displ0[atom] = make_real4((real) posq[indexj0].x - posq[indexi0].x,
(real) posq[indexj0].y - posq[indexi0].y,
(real) posq[indexj0].z - posq[indexi0].z, (real) 0);
}
else {
displ0[atom] = make_real4((real) 0, (real) 0, (real) 0, (real) 0);
}
}
else {
//fixed lab frame displacement
displ1[atom] = displacement1[atom];
displ0[atom] = displacement0[atom];
}
}
}
//reset variable displacement forces
KERNEL void resetDisplForce(int numParticles,
int paddedNumParticles,
GLOBAL mm_ulong* RESTRICT dforce0,
GLOBAL mm_ulong* RESTRICT dforce1) {
mm_ulong zero = 0;
for (int index = GLOBAL_ID; index < numParticles; index += GLOBAL_SIZE) {
dforce0[index] = zero;
dforce0[index+paddedNumParticles] = zero;
dforce0[index+paddedNumParticles*2] = zero;
dforce1[index] = zero;
dforce1[index+paddedNumParticles] = zero;
dforce1[index+paddedNumParticles*2] = zero;
}
}
//add forces due to variable displacements
KERNEL void displForce(int numParticles,
int paddedNumParticles,
GLOBAL mm_long* RESTRICT force0,
GLOBAL mm_long* RESTRICT force1,
GLOBAL mm_long* RESTRICT dforce0,
GLOBAL mm_long* RESTRICT dforce1,
GLOBAL int4* displParticles,
GLOBAL int* RESTRICT atomOrder,
GLOBAL int* RESTRICT invAtomOrder,
GLOBAL int* RESTRICT inner0InvAtomOrder,
GLOBAL int* RESTRICT inner1InvAtomOrder) {
GLOBAL mm_ulong* df0 = (GLOBAL mm_ulong*) dforce0;
GLOBAL mm_ulong* df1 = (GLOBAL mm_ulong*) dforce1;
for (int index = GLOBAL_ID; index < numParticles; index += GLOBAL_SIZE) {
int atom = atomOrder[index];
int pj1 = displParticles[atom].x;
int pi1 = displParticles[atom].y;
int pj0 = displParticles[atom].z;
int pi0 = displParticles[atom].w;
int index0 = inner0InvAtomOrder[atom];
int index1 = inner1InvAtomOrder[atom];
if (pj1 >= 0 && pi1 >= 0) {
int j1 = invAtomOrder[pj1];
int i1 = invAtomOrder[pi1];
ATOMIC_ADD(&df1[j1], (mm_ulong) force1[index1]);
ATOMIC_ADD(&df1[j1+paddedNumParticles], (mm_ulong) force1[index1+paddedNumParticles]);
ATOMIC_ADD(&df1[j1+paddedNumParticles*2], (mm_ulong) force1[index1+paddedNumParticles*2]);
ATOMIC_ADD(&df1[i1], (mm_ulong) -force1[index1]);
ATOMIC_ADD(&df1[i1+paddedNumParticles], (mm_ulong) -force1[index1+paddedNumParticles]);
ATOMIC_ADD(&df1[i1+paddedNumParticles*2],(mm_ulong) -force1[index1+paddedNumParticles*2]);
}
if (pj0 >= 0 && pi0 >= 0) {
int j0 = invAtomOrder[pj0];
int i0 = invAtomOrder[pi0];
ATOMIC_ADD(&df0[j0], (mm_ulong) force0[index0]);
ATOMIC_ADD(&df0[j0+paddedNumParticles], (mm_ulong) force0[index0+paddedNumParticles]);
ATOMIC_ADD(&df0[j0+paddedNumParticles*2],(mm_ulong) force0[index0+paddedNumParticles*2]);
ATOMIC_ADD(&df0[i0], (mm_ulong) -force0[index0]);
ATOMIC_ADD(&df0[i0+paddedNumParticles], (mm_ulong) -force0[index0+paddedNumParticles]);
ATOMIC_ADD(&df0[i0+paddedNumParticles*2], (mm_ulong) -force0[index0+paddedNumParticles*2]);
}
}
}
KERNEL void copyState(int numParticles, KERNEL void copyState(int numParticles,
GLOBAL real4* RESTRICT posq, GLOBAL real4* RESTRICT posq,
GLOBAL real4* RESTRICT posq0, GLOBAL real4* RESTRICT posq0,
GLOBAL real4* RESTRICT posq1, GLOBAL real4* RESTRICT posq1,
GLOBAL float4* RESTRICT displ0, GLOBAL real4* RESTRICT displ0,
GLOBAL float4* RESTRICT displ1, GLOBAL real4* RESTRICT displ1,
GLOBAL int* RESTRICT atomOrder, GLOBAL int* RESTRICT atomOrder,
GLOBAL int* RESTRICT inner0InvAtomOrder, GLOBAL int* RESTRICT inner0InvAtomOrder,
GLOBAL int* RESTRICT inner1InvAtomOrder GLOBAL int* RESTRICT inner1InvAtomOrder
...@@ -50,5 +162,3 @@ KERNEL void copyState(int numParticles, ...@@ -50,5 +162,3 @@ KERNEL void copyState(int numParticles,
#endif #endif
} }
} }
...@@ -1740,7 +1740,7 @@ public: ...@@ -1740,7 +1740,7 @@ public:
} }
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param system the System this kernel will be applied to
* @param force the ATMForce this kernel will be used for * @param force the ATMForce this kernel will be used for
*/ */
...@@ -1774,8 +1774,12 @@ public: ...@@ -1774,8 +1774,12 @@ public:
void copyState(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1); void copyState(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1);
private: private:
int numParticles; int numParticles;
std::vector<Vec3> displ1; std::vector<Vec3> displ1, displ0;
std::vector<Vec3> displ0; std::vector<Vec3> displacement1, displacement0;
std::vector<int> pj1, pi1, pj0, pi0;
void setDisplacements(std::vector<Vec3>& pos);
void displForces(std::vector<Vec3>& force0, std::vector<Vec3>& force1);
void loadParams(int numParticles, const ATMForce& force);
}; };
/** /**
......
...@@ -3053,17 +3053,90 @@ void ReferenceRemoveCMMotionKernel::execute(ContextImpl& context) { ...@@ -3053,17 +3053,90 @@ void ReferenceRemoveCMMotionKernel::execute(ContextImpl& context) {
} }
} }
void ReferenceCalcATMForceKernel::loadParams(int numParticles, const ATMForce& force) {
//vector displacements
displacement1.resize(numParticles);
displacement0.resize(numParticles);
//particle distance displacements
pj1.resize(numParticles);
pi1.resize(numParticles);
pj0.resize(numParticles);
pi0.resize(numParticles);
for (int i = 0; i < numParticles; i++) {
const ATMForce::CoordinateTransformation& transformation = force.getParticleTransformation(i);
if (dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation) != NULL) {
const ATMForce::FixedDisplacement* fd = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation);
const Vec3 d1 = fd->getFixedDisplacement1();
const Vec3 d0 = fd->getFixedDisplacement0();
displacement1[i] = d1;
displacement0[i] = d0;
pj1[i] = pi1[i] = pj0[i] = pi0[i] = -1;
}
else if (dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation) != NULL) {
const ATMForce::ParticleOffsetDisplacement* vd = dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation);
displacement1[i] = Vec3(0, 0, 0);
displacement0[i] = Vec3(0, 0, 0);
pj1[i] = vd->getDestinationParticle1();
pi1[i] = vd->getOriginParticle1();
pj0[i] = vd->getDestinationParticle0();
pi0[i] = vd->getOriginParticle0();
}
else {
throw OpenMMException("loadParams(): invalid particle Transformation");
}
}
}
void ReferenceCalcATMForceKernel::initialize(const System& system, const ATMForce& force) { void ReferenceCalcATMForceKernel::initialize(const System& system, const ATMForce& force) {
numParticles = force.getNumParticles(); numParticles = force.getNumParticles();
//displacement map //displacement map
displ1.resize(numParticles); displ1.resize(numParticles);
displ0.resize(numParticles); displ0.resize(numParticles);
for (int i = 0; i < numParticles; i++) { //load particle parameters from the force object
Vec3 displacement1, displacement0; loadParams(numParticles, force);
force.getParticleParameters(i, displacement1, displacement0 ); }
displ1[i] = displacement1;
displ0[i] = displacement0; void ReferenceCalcATMForceKernel::setDisplacements(vector<Vec3>& pos){
numParticles = pos.size();
for (int i = 0; i < numParticles; i++) {
if (pj1[i] >= 0 && pi1[i] >= 0){
displ1[i] = pos[pj1[i]] - pos[pi1[i]];
if (pi0[i] >= 0 && pj0[i] >= 0){
displ0[i] = pos[pj0[i]] - pos[pi0[i]];
}else{
displ0[i] = Vec3();
}
}else{
displ1[i] = displacement1[i];
displ0[i] = displacement0[i];
}
}
}
//Add forces from variable displacements
void ReferenceCalcATMForceKernel::displForces(vector<Vec3>& force0, vector<Vec3>& force1){
vector<Vec3> dforce1(numParticles), dforce0(numParticles);
for (int i = 0; i < numParticles; i++){
if (pj1[i] >= 0 && pi1[i] >= 0){
dforce1[pj1[i]] += force1[i];
dforce1[pi1[i]] -= force1[i];
}
}
for (int i = 0; i < numParticles; i++){
force1[i] += dforce1[i];
}
for (int i = 0; i < numParticles; i++){
if (pj0[i] >= 0 && pi0[i] >= 0){
dforce0[pj0[i]] += force0[i];
dforce0[pi0[i]] -= force0[i];
}
}
for (int i = 0; i < numParticles; i++){
force0[i] += dforce0[i];
} }
} }
...@@ -3073,15 +3146,17 @@ void ReferenceCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& ...@@ -3073,15 +3146,17 @@ void ReferenceCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl&
vector<Vec3>& force0 = extractForces(innerContext0); vector<Vec3>& force0 = extractForces(innerContext0);
vector<Vec3>& force1 = extractForces(innerContext1); vector<Vec3>& force1 = extractForces(innerContext1);
//update forces and //add gradients from variable displacements
displForces(force0, force1);
//protects from infinite forces when the hybrid potential does //protects from infinite forces when the hybrid potential does
//not depend on u1 or u0, typically at the endpoints //not depend on u1 or u0, typically at the endpoints
double epsi = std::numeric_limits<float>::min(); double epsi = std::numeric_limits<float>::min();
for (int i = 0; i < force.size(); i++) { for (int i = 0; i < force.size(); i++) {
if (fabs(dEdu0) > epsi) if (fabs(dEdu0) > epsi)
force[i] += dEdu0*force0[i]; force[i] += dEdu0*force0[i];
if (fabs(dEdu1) > epsi) if (fabs(dEdu1) > epsi)
force[i] += dEdu1*force1[i]; force[i] += dEdu1*force1[i];
} }
map<string, double>& derivs = extractEnergyParameterDerivatives(context); map<string, double>& derivs = extractEnergyParameterDerivatives(context);
...@@ -3092,6 +3167,9 @@ void ReferenceCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl& ...@@ -3092,6 +3167,9 @@ void ReferenceCalcATMForceKernel::applyForces(ContextImpl& context, ContextImpl&
void ReferenceCalcATMForceKernel::copyState(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1) { void ReferenceCalcATMForceKernel::copyState(ContextImpl& context, ContextImpl& innerContext0, ContextImpl& innerContext1) {
vector<Vec3>& pos = extractPositions(context); vector<Vec3>& pos = extractPositions(context);
//calculate displacement vectors
setDisplacements(pos);
//in the initial state, particles are displaced by displ0 //in the initial state, particles are displaced by displ0
vector<Vec3> pos0(pos); vector<Vec3> pos0(pos);
for (int i = 0; i < pos0.size(); i++) for (int i = 0; i < pos0.size(); i++)
...@@ -3129,12 +3207,7 @@ void ReferenceCalcATMForceKernel::copyParametersToContext(ContextImpl& context, ...@@ -3129,12 +3207,7 @@ void ReferenceCalcATMForceKernel::copyParametersToContext(ContextImpl& context,
throw OpenMMException("copyParametersToContext: The number of ATMForce particles has changed"); throw OpenMMException("copyParametersToContext: The number of ATMForce particles has changed");
displ1.resize(numParticles); displ1.resize(numParticles);
displ0.resize(numParticles); displ0.resize(numParticles);
for (int i = 0; i < numParticles; i++) { loadParams(numParticles, force);
Vec3 displacement1, displacement0;
force.getParticleParameters(i, displacement1, displacement0 );
displ1[i] = displacement1;
displ0[i] = displacement0;
}
} }
void ReferenceCalcCustomCPPForceKernel::initialize(const System& system, CustomCPPForceImpl& force) { void ReferenceCalcCustomCPPForceKernel::initialize(const System& system, CustomCPPForceImpl& force) {
......
...@@ -34,6 +34,9 @@ ...@@ -34,6 +34,9 @@
#include "openmm/internal/windowsExport.h" #include "openmm/internal/windowsExport.h"
#include "openmm/serialization/SerializationProxy.h" #include "openmm/serialization/SerializationProxy.h"
#include "openmm/ATMForce.h"
#include "openmm/Vec3.h"
#include <vector>
namespace OpenMM { namespace OpenMM {
...@@ -46,6 +49,8 @@ namespace OpenMM { ...@@ -46,6 +49,8 @@ namespace OpenMM {
ATMForceProxy(); ATMForceProxy();
void serialize(const void* object, SerializationNode& node) const; void serialize(const void* object, SerializationNode& node) const;
void* deserialize(const SerializationNode& node) const; void* deserialize(const SerializationNode& node) const;
private:
void storeParams(int numParticles, ATMForce& force, const SerializationNode& particles) const ;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -33,6 +33,8 @@ ...@@ -33,6 +33,8 @@
#include "openmm/serialization/SerializationNode.h" #include "openmm/serialization/SerializationNode.h"
#include "openmm/Force.h" #include "openmm/Force.h"
#include "openmm/ATMForce.h" #include "openmm/ATMForce.h"
#include "openmm/OpenMMException.h"
#include <vector>
#include <sstream> #include <sstream>
using namespace OpenMM; using namespace OpenMM;
...@@ -41,6 +43,36 @@ using namespace std; ...@@ -41,6 +43,36 @@ using namespace std;
ATMForceProxy::ATMForceProxy() : SerializationProxy("ATMForce") { ATMForceProxy::ATMForceProxy() : SerializationProxy("ATMForce") {
} }
void ATMForceProxy::storeParams(int numParticles, ATMForce& force, const SerializationNode& particles) const {
for (auto& p : particles.getChildren()){
//support older serialized ATMForce instances that did not store the transformation type
//or the particle offset displacement indexes
if (p.hasProperty("type")) {
//normal behavior
string type = p.getStringProperty("type");
if (type == "fixed") {
force.addParticle(new ATMForce::FixedDisplacement(Vec3(p.getDoubleProperty("d1x"), p.getDoubleProperty("d1y"), p.getDoubleProperty("d1z")), Vec3(p.getDoubleProperty("d0x"), p.getDoubleProperty("d0y"), p.getDoubleProperty("d0z"))));
}
else if (type == "offset") {
force.addParticle(new ATMForce::ParticleOffsetDisplacement(p.getIntProperty("pj1"), p.getIntProperty("pi1"), p.getIntProperty("pj0"), p.getIntProperty("pi0")));
}
else {
throw OpenMMException("storeParams(): invalid particle transformation type");
}
}
else if (p.hasProperty("pj1") && p.getIntProperty("pj1") >= 0) {
//missing type, but particle offset indexes are present
force.addParticle(new ATMForce::ParticleOffsetDisplacement(p.getIntProperty("pj1"), p.getIntProperty("pi1"), p.getIntProperty("pj0"), p.getIntProperty("pi0")));
}
else {
//only displacements are present
force.addParticle(new ATMForce::FixedDisplacement(Vec3(p.getDoubleProperty("d1x"), p.getDoubleProperty("d1y"), p.getDoubleProperty("d1z")), Vec3(p.getDoubleProperty("d0x"), p.getDoubleProperty("d0y"), p.getDoubleProperty("d0z"))));
}
}
}
void ATMForceProxy::serialize(const void* object, SerializationNode& node) const { void ATMForceProxy::serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 0); node.setIntProperty("version", 0);
const ATMForce& force = *reinterpret_cast<const ATMForce*> (object); const ATMForce& force = *reinterpret_cast<const ATMForce*> (object);
...@@ -61,11 +93,37 @@ void ATMForceProxy::serialize(const void* object, SerializationNode& node) const ...@@ -61,11 +93,37 @@ void ATMForceProxy::serialize(const void* object, SerializationNode& node) const
f.createChildNode("Force", &force.getForce(i)); f.createChildNode("Force", &force.getForce(i));
} }
SerializationNode& particles = node.createChildNode("Particles"); SerializationNode& particles = node.createChildNode("Particles");
int numParticles = force.getNumParticles();
string type;
int j1, i1, j0, i0;
Vec3 d1, d0;
for (int i = 0; i < force.getNumParticles(); i++) { for (int i = 0; i < force.getNumParticles(); i++) {
Vec3 d1, d0; const ATMForce::CoordinateTransformation& transformation = force.getParticleTransformation(i);
force.getParticleParameters(i, d1, d0); if (dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation) != nullptr) {
particles.createChildNode("Particle").setDoubleProperty("d1x", d1[0]).setDoubleProperty("d1y", d1[1]).setDoubleProperty("d1z", d1[2]) const ATMForce::FixedDisplacement* fd = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation);
.setDoubleProperty("d0x", d0[0]).setDoubleProperty("d0y", d0[1]).setDoubleProperty("d0z", d0[2]); d1 = fd->getFixedDisplacement1();
d0 = fd->getFixedDisplacement0();
j1 = i1 = j0 = i0 = -1;
type = "fixed";
}
else if (dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation) != nullptr) {
const ATMForce::ParticleOffsetDisplacement* vd = dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation);
d1 = Vec3(0, 0, 0);
d0 = Vec3(0, 0, 0);
j1 = vd->getDestinationParticle1();
i1 = vd->getOriginParticle1();
j0 = vd->getDestinationParticle0();
i0 = vd->getOriginParticle0();
type = "offset";
}
else {
throw OpenMMException("serialize(): invalid particle CoordinateTransformation");
}
particles.createChildNode("Particle").setStringProperty("type", type)
.setDoubleProperty("d1x", d1[0]).setDoubleProperty("d1y", d1[1]).setDoubleProperty("d1z", d1[2])
.setDoubleProperty("d0x", d0[0]).setDoubleProperty("d0y", d0[1]).setDoubleProperty("d0z", d0[2])
.setIntProperty("pj1", j1).setIntProperty("pi1", i1).setIntProperty("pj0", j0).setIntProperty("pi0", i0);
} }
} }
...@@ -73,7 +131,7 @@ void* ATMForceProxy::deserialize(const SerializationNode& node) const { ...@@ -73,7 +131,7 @@ void* ATMForceProxy::deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version"); int version = node.getIntProperty("version");
if (version != 0) if (version != 0)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
ATMForce* force = NULL; ATMForce* force = nullptr;
try { try {
ATMForce* force = new ATMForce(node.getStringProperty("energy")); ATMForce* force = new ATMForce(node.getStringProperty("energy"));
force->setForceGroup(node.getIntProperty("forceGroup", 0)); force->setForceGroup(node.getIntProperty("forceGroup", 0));
...@@ -88,9 +146,9 @@ void* ATMForceProxy::deserialize(const SerializationNode& node) const { ...@@ -88,9 +146,9 @@ void* ATMForceProxy::deserialize(const SerializationNode& node) const {
for (auto& f : forces.getChildren()) for (auto& f : forces.getChildren())
force->addForce(f.getChildren()[0].decodeObject<Force>()); force->addForce(f.getChildren()[0].decodeObject<Force>());
const SerializationNode& particles = node.getChildNode("Particles"); const SerializationNode& particles = node.getChildNode("Particles");
for (auto& p : particles.getChildren())
force->addParticle(Vec3(p.getDoubleProperty("d1x"), p.getDoubleProperty("d1y"), p.getDoubleProperty("d1z")), storeParams(force->getNumParticles(), *force, particles);
Vec3(p.getDoubleProperty("d0x"), p.getDoubleProperty("d0y"), p.getDoubleProperty("d0z")));
return force; return force;
} }
catch (...) { catch (...) {
......
...@@ -55,8 +55,8 @@ void testSerialization() { ...@@ -55,8 +55,8 @@ void testSerialization() {
HarmonicAngleForce* v2 = new HarmonicAngleForce(); HarmonicAngleForce* v2 = new HarmonicAngleForce();
v2->addAngle(3, 11, 15, 0.4, 0.2); v2->addAngle(3, 11, 15, 0.4, 0.2);
force.addForce(v2); force.addForce(v2);
force.addParticle(Vec3(1, 2, 3)); force.addParticle(new ATMForce::FixedDisplacement(Vec3(1, 2, 3)));
force.addParticle(Vec3(0, 0, -1), Vec3(3, 2, 1)); force.addParticle(new ATMForce::ParticleOffsetDisplacement(0, 1));
// Serialize and then deserialize it. // Serialize and then deserialize it.
...@@ -87,11 +87,38 @@ void testSerialization() { ...@@ -87,11 +87,38 @@ void testSerialization() {
} }
ASSERT_EQUAL(force.getNumParticles(), force2.getNumParticles()); ASSERT_EQUAL(force.getNumParticles(), force2.getNumParticles());
for (int i = 0; i < force.getNumParticles(); i++) { for (int i = 0; i < force.getNumParticles(); i++) {
Vec3 d1a, d1b, d0a, d0b; const ATMForce::CoordinateTransformation& transformation = force.getParticleTransformation(i);
force.getParticleParameters(i, d1a, d0a); const ATMForce::CoordinateTransformation& transformation2 = force2.getParticleTransformation(i);
force2.getParticleParameters(i, d1b, d0b); if (dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation) != nullptr) {
ASSERT_EQUAL_VEC(d1a, d1b, 0.0); const ATMForce::FixedDisplacement* fd = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation);
ASSERT_EQUAL_VEC(d0a, d0b, 0.0); const ATMForce::FixedDisplacement* fd2 = dynamic_cast<const ATMForce::FixedDisplacement*>(&transformation2);
const Vec3 d1a = fd->getFixedDisplacement1();
const Vec3 d0a = fd->getFixedDisplacement0();
const Vec3 d1b = fd2->getFixedDisplacement1();
const Vec3 d0b = fd2->getFixedDisplacement0();
ASSERT_EQUAL_VEC(d1a, d1b, 0.0);
ASSERT_EQUAL_VEC(d0a, d0b, 0.0);
}
else if (dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation) != nullptr) {
const ATMForce::ParticleOffsetDisplacement* vd = dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation);
const ATMForce::ParticleOffsetDisplacement* vd2 = dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&transformation2);
int j1a = vd->getDestinationParticle1();
int i1a = vd->getOriginParticle1();
int j0a = vd->getDestinationParticle0();
int i0a = vd->getOriginParticle0();
int j1b = vd2->getDestinationParticle1();
int i1b = vd2->getOriginParticle1();
int j0b = vd2->getDestinationParticle0();
int i0b = vd2->getOriginParticle0();
ASSERT_EQUAL(j1a, j1b);
ASSERT_EQUAL(i1a, i1b);
ASSERT_EQUAL(j0a, j0b);
ASSERT_EQUAL(i0a, i0b);
}
else {
throwException(__FILE__, __LINE__, "Unknown CoordinateTransformation type");
}
} }
} }
......
...@@ -50,10 +50,44 @@ ...@@ -50,10 +50,44 @@
#include <random> #include <random>
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include <string>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
void testAPI(){
ATMForce* atm = new ATMForce(0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.6, 0.8, -1.0);
atm->addParticle(Vec3(1, 2, 3), Vec3(4, 5, 6)); //old interface
atm->addParticle(new ATMForce::FixedDisplacement(Vec3(7, 8, 9), Vec3(10, 11, 12)));
atm->addParticle(new ATMForce::ParticleOffsetDisplacement(1, 0));
atm->addParticle();
Vec3 d1, d0;
atm->getParticleParameters(0, d1, d0);
ASSERT_EQUAL_VEC(Vec3(1, 2, 3), d1, 1e-6);
ASSERT_EQUAL_VEC(Vec3(4, 5, 6), d0, 1e-6);
const ATMForce::FixedDisplacement* fd = (dynamic_cast<const ATMForce::FixedDisplacement*>(&(atm->getParticleTransformation(1))));
d1 = fd->getFixedDisplacement1();
d0 = fd->getFixedDisplacement0();
ASSERT_EQUAL_VEC(Vec3(7, 8, 9), d1, 1e-6);
ASSERT_EQUAL_VEC(Vec3(10, 11, 12), d0, 1e-6);
const ATMForce::ParticleOffsetDisplacement* vt = (dynamic_cast<const ATMForce::ParticleOffsetDisplacement*>(&(atm->getParticleTransformation(2))));
int j1 = vt->getDestinationParticle1();
int i1 = vt->getOriginParticle1();
int j0 = vt->getDestinationParticle0();
int i0 = vt->getOriginParticle0();
ASSERT_EQUAL( 1, j1);
ASSERT_EQUAL( 0, i1);
ASSERT_EQUAL(-1, j0);
ASSERT_EQUAL(-1, i0);
atm->getParticleParameters(3, d1, d0);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), d1, 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), d0, 1e-6);
}
void test2Particles() { void test2Particles() {
// A pair of particles tethered by an harmonic bond. // A pair of particles tethered by an harmonic bond.
// Displace the second one to test energy and forces at different lambda values // Displace the second one to test energy and forces at different lambda values
...@@ -76,10 +110,9 @@ void test2Particles() { ...@@ -76,10 +110,9 @@ void test2Particles() {
bond->addBond(0, 1); bond->addBond(0, 1);
ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0, 0, umax, ubcore, acore, direction); ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0, 0, umax, ubcore, acore, direction);
Vec3 nodispl = Vec3(0., 0., 0.); Vec3 displ = Vec3(1., 0., 0.);
Vec3 displ = Vec3(1., 0., 0.); atm->addParticle();
atm->addParticle( nodispl ); atm->addParticle(new ATMForce::FixedDisplacement(displ));
atm->addParticle( displ );
atm->addForce(bond); atm->addForce(bond);
atm->addEnergyParameterDerivative(ATMForce::Lambda1()); atm->addEnergyParameterDerivative(ATMForce::Lambda1());
atm->addEnergyParameterDerivative(ATMForce::Lambda2()); atm->addEnergyParameterDerivative(ATMForce::Lambda2());
...@@ -110,12 +143,78 @@ void test2Particles() { ...@@ -110,12 +143,78 @@ void test2Particles() {
} }
} }
void test3ParticlesSwap() {
// A pair of particles tethered by harmonic bonds to a central particle.
// Swap the pair and test energy and forces at different lambda values
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
double lmbd = 0.5;
double umax = 0.;
double ubcore= 0.;
double acore = 0.;
double direction = 1.0;
Vec3 origin = Vec3(0., 0., 0.);
Vec3 r1 = Vec3(1., 0., 0.);
double r1sq = r1[0]*r1[0] + r1[1]*r1[1] + r1[2]*r1[2];
Vec3 r2 = Vec3(-2., 0., 0.);
double r2sq = r2[0]*r2[0] + r2[1]*r2[1] + r2[2]*r2[2];
vector<Vec3> positions(3);
positions[0] = origin;
positions[1] = r1;
positions[2] = r2;
CustomBondForce* bond = new CustomBondForce("0.5*kf*r^2");
double kf1 = 0.31;
double kf2 = 0.17;
bond->addPerBondParameter("kf");
std::vector<double> kf1v = {kf1};
bond->addBond(0, 1, kf1v);
std::vector<double> kf2v = {kf2};
bond->addBond(0, 2, kf2v);
ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0, 0, umax, ubcore, acore, direction);
//swap particles 1 and 2
atm->addParticle( ); //particle 0 is not displaced
atm->addParticle(new ATMForce::ParticleOffsetDisplacement(2, 1) );
atm->addParticle(new ATMForce::ParticleOffsetDisplacement(1, 2) );
atm->addForce(bond);
system.addForce(atm);
VerletIntegrator integrator(1.0);
Context context(system, integrator, platform);
context.setPositions(positions);
for (double lm : {0.0, 0.5, 1.0}) {
context.setParameter(ATMForce::Lambda1(), lm);
context.setParameter(ATMForce::Lambda2(), lm);
State state = context.getState(State::Energy | State::Forces );
double epot = state.getPotentialEnergy();
double u0, u1, energy;
atm->getPerturbationEnergy(context, u1, u0, energy);
double epert = u1 - u0;
ASSERT_EQUAL_TOL(energy, epot, 1e-6);
ASSERT_EQUAL_TOL(0.5*kf1*r1sq + 0.5*kf2*r2sq, u0, 1e-6);
ASSERT_EQUAL_TOL(0.5*kf1*r2sq + 0.5*kf2*r1sq, u1, 1e-6);
ASSERT_EQUAL_TOL(0.5*kf1*(r2sq-r1sq) + 0.5*kf2*(r1sq-r2sq), epert, 1e-6);
ASSERT_EQUAL_TOL(u0 + lm*epert, epot, 1e-6);
ASSERT_EQUAL_VEC(- ( ((1.-lm)*kf1+lm*kf2)*r1 ), state.getForces()[1], 1e-6);
ASSERT_EQUAL_VEC(- ( ((1.-lm)*kf2+lm*kf1)*r2 ), state.getForces()[2], 1e-6);
}
}
void test2Particles2Displacement0() { void test2Particles2Displacement0() {
// A pair of particles tethered by an harmonic bond. // A pair of particles tethered by an harmonic bond.
// Displace the second one to test energy and forces at different lambda values // Displace the second one to test energy and forces at different lambda values
// In this version the second particle is displaced in both the initial and final states // In this version the second particle is displaced in both the initial and final states
// by different amounts. // by different amounts.
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
system.addParticle(1.0); system.addParticle(1.0);
...@@ -132,14 +231,13 @@ void test2Particles2Displacement0() { ...@@ -132,14 +231,13 @@ void test2Particles2Displacement0() {
CustomBondForce* bond = new CustomBondForce("0.5*r^2"); CustomBondForce* bond = new CustomBondForce("0.5*r^2");
bond->addBond(0, 1); bond->addBond(0, 1);
ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0., 0., umax, ubcore, acore, direction); ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0., 0., umax, ubcore, acore, direction);
//first particle is not displaced at either state //first particle is not displaced at either state
Vec3 nodispl = Vec3(0., 0., 0.); atm->addParticle();
atm->addParticle( nodispl );
//second particle is displaced at both states but by the same amount (1,0,0) //second particle is displaced at both states but by the same amount (1,0,0)
Vec3 displ0 = Vec3(1., 0., 0.); Vec3 displ0 = Vec3(1., 0., 0.);
atm->addParticle( displ0, displ0 ); atm->addParticle(new ATMForce::FixedDisplacement(displ0, displ0));
atm->addForce(bond); atm->addForce(bond);
system.addForce(atm); system.addForce(atm);
...@@ -150,7 +248,7 @@ void test2Particles2Displacement0() { ...@@ -150,7 +248,7 @@ void test2Particles2Displacement0() {
State state; State state;
double epot, epert; double epot, epert;
double u0, u1, energy; double u0, u1, energy;
// U = U0 + lambda*epert; epert = U1 - U0 // U = U0 + lambda*epert; epert = U1 - U0
// When the second particle is displaced by the same amount at each state, // When the second particle is displaced by the same amount at each state,
...@@ -215,10 +313,9 @@ void test2ParticlesSoftCore() { ...@@ -215,10 +313,9 @@ void test2ParticlesSoftCore() {
bond->addBond(0, 1); bond->addBond(0, 1);
ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0, 0, umax, ubcore, acore, direction); ATMForce* atm = new ATMForce(lmbd, lmbd, 0., 0, 0, umax, ubcore, acore, direction);
Vec3 nodispl = Vec3(0., 0., 0.);
Vec3 displ = Vec3(5., 0., 0.); Vec3 displ = Vec3(5., 0., 0.);
atm->addParticle( nodispl ); atm->addParticle();
atm->addParticle( displ ); atm->addParticle(new ATMForce::FixedDisplacement(displ));
atm->addForce(bond); atm->addForce(bond);
system.addForce(atm); system.addForce(atm);
...@@ -261,10 +358,10 @@ void testNonbonded() { ...@@ -261,10 +358,10 @@ void testNonbonded() {
for (int i = 0; i < 6; i++) for (int i = 0; i < 6; i++)
for (int j = 0; j < 6; j++) for (int j = 0; j < 6; j++)
for (int k = 0; k < 6; k++) { for (int k = 0; k < 6; k++) {
positions.push_back(Vec3(spacing*i+offset, spacing*j+offset, spacing*k+offset)); positions.push_back(Vec3(spacing*i+offset, spacing*j+offset, spacing*k+offset));
system.addParticle(10.0); system.addParticle(10.0);
nbforce->addParticle(0, 0.3, 1.0); nbforce->addParticle(0, 0.3, 1.0);
atm->addParticle(Vec3()); atm->addParticle();
} }
auto rng = std::default_random_engine {}; auto rng = std::default_random_engine {};
std::shuffle(std::begin(positions), std::end(positions), rng); std::shuffle(std::begin(positions), std::end(positions), rng);
...@@ -288,7 +385,7 @@ void testNonbonded() { ...@@ -288,7 +385,7 @@ void testNonbonded() {
atm->getPerturbationEnergy(context1, u1, u0, energy); atm->getPerturbationEnergy(context1, u1, u0, energy);
double epert1 = u1 - u0; double epert1 = u1 - u0;
//in this second scenario the non-bonded force is remove from the System //in this second scenario the non-bonded force is removed from the System
system.removeForce(0); system.removeForce(0);
LangevinMiddleIntegrator integrator2(300, 1.0, 0.004); LangevinMiddleIntegrator integrator2(300, 1.0, 0.004);
Context context2(system, integrator2, platform); Context context2(system, integrator2, platform);
...@@ -324,10 +421,10 @@ void testNonbondedwithEndpointClash() { ...@@ -324,10 +421,10 @@ void testNonbondedwithEndpointClash() {
for (int i = 0; i < 6; i++) for (int i = 0; i < 6; i++)
for (int j = 0; j < 6; j++) for (int j = 0; j < 6; j++)
for (int k = 0; k < 6; k++) { for (int k = 0; k < 6; k++) {
positions.push_back(Vec3(spacing*i+offset, spacing*j+offset, spacing*k+offset)); positions.push_back(Vec3(spacing*i+offset, spacing*j+offset, spacing*k+offset));
system.addParticle(10.0); system.addParticle(10.0);
nbforce->addParticle(0, 0.3, 1.0); nbforce->addParticle(0, 0.3, 1.0);
atm->addParticle(Vec3(0,0,0)); atm->addParticle();
} }
//places first particle almost on top of another particle in displaced system //places first particle almost on top of another particle in displaced system
atm->setParticleParameters(0, Vec3(spacing+1.e-4, 0, 0), Vec3(0.0, 0, 0)); atm->setParticleParameters(0, Vec3(spacing+1.e-4, 0, 0), Vec3(0.0, 0, 0));
...@@ -371,10 +468,9 @@ void testParticlesCustomExpressionLinear() { ...@@ -371,10 +468,9 @@ void testParticlesCustomExpressionLinear() {
double lmbd = 0.5; double lmbd = 0.5;
ATMForce* atm = new ATMForce("u0 + Lambda*(u1 - u0)"); ATMForce* atm = new ATMForce("u0 + Lambda*(u1 - u0)");
atm->addGlobalParameter("Lambda", lmbd); atm->addGlobalParameter("Lambda", lmbd);
Vec3 nodispl = Vec3(0., 0., 0.);
Vec3 displ = Vec3(5., 0., 0.); Vec3 displ = Vec3(5., 0., 0.);
atm->addParticle( nodispl ); atm->addParticle();
atm->addParticle( displ ); atm->addParticle(new ATMForce::FixedDisplacement(displ));
atm->addForce(bond); atm->addForce(bond);
system.addForce(atm); system.addForce(atm);
...@@ -407,7 +503,6 @@ void testParticlesCustomExpressionSoftplus() { ...@@ -407,7 +503,6 @@ void testParticlesCustomExpressionSoftplus() {
positions[0] = Vec3(0, 0, 0); positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, 0, 0); positions[1] = Vec3(0, 0, 0);
Vec3 nodispl = Vec3(0., 0., 0.);
Vec3 displ = Vec3(2., 0., 0.); Vec3 displ = Vec3(2., 0., 0.);
CustomBondForce* bond = new CustomBondForce("0.5*r^2"); CustomBondForce* bond = new CustomBondForce("0.5*r^2");
...@@ -426,8 +521,8 @@ void testParticlesCustomExpressionSoftplus() { ...@@ -426,8 +521,8 @@ void testParticlesCustomExpressionSoftplus() {
atm->addGlobalParameter("Uh", uh); atm->addGlobalParameter("Uh", uh);
atm->addGlobalParameter("W0", w0); atm->addGlobalParameter("W0", w0);
atm->addParticle( nodispl ); atm->addParticle();
atm->addParticle( displ ); atm->addParticle(new ATMForce::FixedDisplacement(displ));
atm->addForce(bond); atm->addForce(bond);
system.addForce(atm); system.addForce(atm);
...@@ -460,7 +555,7 @@ void testParticlesCustomExpressionSoftplus() { ...@@ -460,7 +555,7 @@ void testParticlesCustomExpressionSoftplus() {
void testLargeSystem() { void testLargeSystem() {
// Create a system with lots of particles, each displaced differently. // Create a system with lots of particles, each displaced differently.
int numParticles = 1000; int numParticles = 1000;
System system; System system;
system.setDefaultPeriodicBoxVectors(Vec3(3, 0, 0), Vec3(0, 3, 0), Vec3(0, 0, 3)); system.setDefaultPeriodicBoxVectors(Vec3(3, 0, 0), Vec3(0, 3, 0), Vec3(0, 0, 3));
...@@ -477,7 +572,7 @@ void testLargeSystem() { ...@@ -477,7 +572,7 @@ void testLargeSystem() {
Vec3 d(genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5); Vec3 d(genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5);
displacements.push_back(d); displacements.push_back(d);
external->addParticle(i); external->addParticle(i);
atm->addParticle(d); atm->addParticle(new ATMForce::FixedDisplacement(d));
} }
// Also add nonbonded forces to trigger atom reordering on the GPU. // Also add nonbonded forces to trigger atom reordering on the GPU.
...@@ -494,7 +589,7 @@ void testLargeSystem() { ...@@ -494,7 +589,7 @@ void testLargeSystem() {
nb1->addParticle({(double) (i%3)}); nb1->addParticle({(double) (i%3)});
nb1->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic); nb1->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
atm->addForce(nb1); atm->addForce(nb1);
// Evaluate the forces to see if the particles are at the correct positions. // Evaluate the forces to see if the particles are at the correct positions.
VerletIntegrator integrator(1.0); VerletIntegrator integrator(1.0);
...@@ -512,6 +607,97 @@ void testLargeSystem() { ...@@ -512,6 +607,97 @@ void testLargeSystem() {
} }
} }
void testLargeSystemSwap() {
// Create a system with lots of particles in an external field
// that depends on atom indexes. Swap their positions, check
// energies and forces.
int numParticles = 1000;
System system;
system.setDefaultPeriodicBoxVectors(Vec3(3, 0, 0), Vec3(0, 3, 0), Vec3(0, 0, 3));
CustomExternalForce* external = new CustomExternalForce("qf*(x^2 + 2*y^2 + 3*z^2)");
external->addPerParticleParameter("qf");
ATMForce* atm = new ATMForce(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0);
atm->addForce(external);
system.addForce(atm);
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
Vec3 nodispl = Vec3(0,0,0);
vector<Vec3> positions;
vector<int> target_particle(numParticles);
for (int i = 0; i < numParticles; i++) {
target_particle[i] = i;
}
auto rng = default_random_engine {};
shuffle(begin(target_particle), end(target_particle), rng);
vector<int> target_particle_inv(numParticles);
for (int i = 0; i < numParticles; i++) {
target_particle_inv[target_particle[i]] = i;
}
vector<double> qf(numParticles);
for (int i = 0; i < numParticles; i++)
qf[i] = (double)i/(double)numParticles;
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
positions.push_back(3*Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt)));
external->addParticle(i, {qf[i]});
atm->addParticle(new ATMForce::ParticleOffsetDisplacement(target_particle[i], i));
}
double energy0 = 0.;
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[i];
energy0 += qf[i]*(pos[0]*pos[0]+2*pos[1]*pos[1]+3*pos[2]*pos[2]);
}
double energy1 = 0.;
for (int i = 0; i < numParticles; i++) {
Vec3 pos = positions[target_particle[i]];
energy1 += qf[i]*(pos[0]*pos[0]+2*pos[1]*pos[1]+3*pos[2]*pos[2]);
}
// Also add nonbonded forces to trigger atom reordering on the GPU.
CustomNonbondedForce* nb = new CustomNonbondedForce("a*r^2");
nb->addGlobalParameter("a", 0.0);
for (int i = 0; i < numParticles; i++)
nb->addParticle();
nb->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
system.addForce(nb);
CustomNonbondedForce* nb1 = new CustomNonbondedForce("0");
nb1->addPerParticleParameter("b");
for (int i = 0; i < numParticles; i++)
nb1->addParticle({(double) (i%3)});
nb1->setNonbondedMethod(CustomNonbondedForce::CutoffPeriodic);
atm->addForce(nb1);
// Evaluate energies and forces at lambda 0 and 1
VerletIntegrator integrator(1.0);
Context context(system, integrator, platform);
context.setPositions(positions);
for (double lambda : {0.0, 1.0}) {
context.setParameter(ATMForce::Lambda1(), lambda);
context.setParameter(ATMForce::Lambda2(), lambda);
State state = context.getState(State::Energy | State::Forces);
double u1, u0, energy;
double epot = state.getPotentialEnergy();
atm->getPerturbationEnergy(context, u1, u0, energy);
ASSERT_EQUAL_TOL(u0, energy0, 1e-6);
ASSERT_EQUAL_TOL(u1, energy1, 1e-6);
ASSERT_EQUAL_TOL(u0+lambda*(u1-u0), epot, 1e-6);
for (int i = 0; i < numParticles; i++) {
int l;
if (lambda > 0){
l = target_particle_inv[i];
}else{
l = i;
}
Vec3 pos = positions[i];
Vec3 expectedForce(-2*pos[0], -4*pos[1], -6*pos[2]);
ASSERT_EQUAL_VEC(qf[l]*expectedForce, state.getForces()[i], 1e-6);
}
}
}
void testChangingBoxVectors() { void testChangingBoxVectors() {
// Create a periodic system with incorrect default box vectors. // Create a periodic system with incorrect default box vectors.
...@@ -530,7 +716,7 @@ void testChangingBoxVectors() { ...@@ -530,7 +716,7 @@ void testChangingBoxVectors() {
system.addParticle(1.0); system.addParticle(1.0);
positions.push_back(3*Vec3(genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5)); positions.push_back(3*Vec3(genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5, genrand_real2(sfmt)-0.5));
force->addParticle(0.0, 0.1, 1.0); force->addParticle(0.0, 0.1, 1.0);
atm->addParticle(Vec3()); atm->addParticle();
for (int j = 0; j < i; j++) { for (int j = 0; j < i; j++) {
Vec3 delta = positions[i]-positions[j]; Vec3 delta = positions[i]-positions[j];
for (int k = 0; k < 3; k++) for (int k = 0; k < 3; k++)
...@@ -554,7 +740,7 @@ void testChangingBoxVectors() { ...@@ -554,7 +740,7 @@ void testChangingBoxVectors() {
void testMolecules() { void testMolecules() {
// Verify that ATMForce correctly propagates information about molecules // Verify that ATMForce correctly propagates information about molecules
// from the forces it contains. // from the forces it contains.
System system; System system;
for (int i = 0; i < 5; i++) for (int i = 0; i < 5; i++)
system.addParticle(1.0); system.addParticle(1.0);
...@@ -604,7 +790,7 @@ void testSimulation() { ...@@ -604,7 +790,7 @@ void testSimulation() {
system.addParticle(10.0); system.addParticle(10.0);
positions.push_back(Vec3(0.6*i, 0.6*j, 0.6*k)); positions.push_back(Vec3(0.6*i, 0.6*j, 0.6*k));
nb->addParticle(0, 0.3, 1.0); nb->addParticle(0, 0.3, 1.0);
atm->addParticle(Vec3()); atm->addParticle();
} }
atm->setParticleParameters(0, Vec3(0.3, 0, 0), Vec3(-0.3, 0, 0)); atm->setParticleParameters(0, Vec3(0.3, 0, 0), Vec3(-0.3, 0, 0));
...@@ -634,15 +820,18 @@ void runPlatformTests(); ...@@ -634,15 +820,18 @@ void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
initializeTests(argc, argv); initializeTests(argc, argv);
testAPI();
test2Particles(); test2Particles();
test3ParticlesSwap();
test2Particles2Displacement0(); test2Particles2Displacement0();
test2ParticlesSoftCore(); test2ParticlesSoftCore();
testNonbonded(); testNonbonded();
testNonbondedwithEndpointClash(); testNonbondedwithEndpointClash();
testParticlesCustomExpressionLinear(); testParticlesCustomExpressionLinear();
testParticlesCustomExpressionSoftplus(); testParticlesCustomExpressionSoftplus();
testLargeSystem(); testLargeSystem();
testChangingBoxVectors(); testLargeSystemSwap();
testChangingBoxVectors();
testMolecules(); testMolecules();
testSimulation(); testSimulation();
runPlatformTests(); runPlatformTests();
......
...@@ -38,7 +38,7 @@ def getText(subNodePath, node): ...@@ -38,7 +38,7 @@ def getText(subNodePath, node):
return s.strip() return s.strip()
def convertOpenMMPrefix(name): def convertOpenMMPrefix(name):
return name.replace('OpenMM::', 'OpenMM_') return name.replace('::', '_')
OPENMM_RE_PATTERN=re.compile("(.*)OpenMM:[a-zA-Z0-9_:]*:(.*)") OPENMM_RE_PATTERN=re.compile("(.*)OpenMM:[a-zA-Z0-9_:]*:(.*)")
def stripOpenMMPrefix(name, rePattern=OPENMM_RE_PATTERN): def stripOpenMMPrefix(name, rePattern=OPENMM_RE_PATTERN):
...@@ -101,6 +101,7 @@ class WrapperGenerator: ...@@ -101,6 +101,7 @@ class WrapperGenerator:
] ]
self.skipMethods = [s.replace(' ', '') for s in self.skipMethods] self.skipMethods = [s.replace(' ', '') for s in self.skipMethods]
self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy'] self.hideClasses = ['Kernel', 'KernelImpl', 'KernelFactory', 'ContextImpl', 'SerializationNode', 'SerializationProxy']
self.renameTypes = {}
self.nodeByID={} self.nodeByID={}
# Read all the XML files and merge them into a single document. # Read all the XML files and merge them into a single document.
...@@ -142,6 +143,10 @@ class WrapperGenerator: ...@@ -142,6 +143,10 @@ class WrapperGenerator:
baseNode = self.getNodeByID(baseNodeID) baseNode = self.getNodeByID(baseNodeID)
self.findBaseNodes(baseNode, excludedClassNodes) self.findBaseNodes(baseNode, excludedClassNodes)
excludedClassNodes.append(node) excludedClassNodes.append(node)
for inner in findNodes(node, "innerclass", prot="public"):
fullName = getNodeText(inner)
shortName = fullName.rsplit("::")[-1]
self.renameTypes[shortName] = fullName
def getClassMethods(self, classNode): def getClassMethods(self, classNode):
className = getText("compoundname", classNode) className = getText("compoundname", classNode)
...@@ -298,8 +303,8 @@ class CHeaderGenerator(WrapperGenerator): ...@@ -298,8 +303,8 @@ class CHeaderGenerator(WrapperGenerator):
if methodName in nameCount: if methodName in nameCount:
# There are multiple methods with the same name. # There are multiple methods with the same name.
count = nameCount[methodName] count = nameCount[methodName]
methodName = "%s_%d" % (methodName, count)
nameCount[methodName] = count+1 nameCount[methodName] = count+1
methodName = "%s_%d" % (methodName, count)
else: else:
nameCount[methodName] = 1 nameCount[methodName] = 1
self.out.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName)) self.out.write("extern OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
...@@ -564,14 +569,16 @@ class CSourceGenerator(WrapperGenerator): ...@@ -564,14 +569,16 @@ class CSourceGenerator(WrapperGenerator):
if methodName in nameCount: if methodName in nameCount:
# There are multiple methods with the same name. # There are multiple methods with the same name.
count = nameCount[methodName] count = nameCount[methodName]
methodName = "%s_%d" % (methodName, count)
nameCount[methodName] = count+1 nameCount[methodName] = count+1
methodName = "%s_%d" % (methodName, count)
else: else:
nameCount[methodName] = 1 nameCount[methodName] = 1
methodType = getText("type", methodNode) methodType = getText("type", methodNode)
returnType = self.getType(methodType) returnType = self.getType(methodType)
if methodType in self.classesByShortName: if methodType in self.classesByShortName:
methodType = self.classesByShortName[methodType] methodType = self.classesByShortName[methodType]
for key, value in self.renameTypes.items():
methodType = methodType.replace(key, value)
self.out.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName)) self.out.write("OPENMM_EXPORT %s %s_%s(" % (returnType, typeName, methodName))
isInstanceMethod = (methodNode.attrib['static'] != 'yes') isInstanceMethod = (methodNode.attrib['static'] != 'yes')
if isInstanceMethod: if isInstanceMethod:
...@@ -662,7 +669,7 @@ class CSourceGenerator(WrapperGenerator): ...@@ -662,7 +669,7 @@ class CSourceGenerator(WrapperGenerator):
if wrappedType == type: if wrappedType == type:
return value; return value;
if type.endswith('*') or type.endswith('&'): if type.endswith('*') or type.endswith('&'):
return 'reinterpret_cast<%s>(%s)' % (wrappedType, value) return 'reinterpret_cast<%s>(%s)' % (wrappedType.replace('::', '_'), value)
return 'static_cast<%s>(%s)' % (wrappedType, value) return 'static_cast<%s>(%s)' % (wrappedType, value)
def unwrapValue(self, type, value): def unwrapValue(self, type, value):
...@@ -677,8 +684,23 @@ class CSourceGenerator(WrapperGenerator): ...@@ -677,8 +684,23 @@ class CSourceGenerator(WrapperGenerator):
return 'static_cast<%s>(%s)' % (self.classesByShortName[type], value) return 'static_cast<%s>(%s)' % (self.classesByShortName[type], value)
if type == 'bool': if type == 'bool':
return value return value
type = self.convertShortName(type)
return 'reinterpret_cast<%s>(%s)' % (type, value) return 'reinterpret_cast<%s>(%s)' % (type, value)
def convertShortName(self, shortName):
name = shortName
prefix = ''
suffix = ''
if name.endswith('&') or name.endswith('*'):
suffix = name[-1]
name = name[:-1]
if name.startswith('const '):
prefix = 'const '
name = name[6:]
if name.strip() in self.classesByShortName:
return f'{prefix}{self.classesByShortName[name.strip()]}{suffix}'
return shortName
def writeOutput(self): def writeOutput(self):
print(""" print("""
#include "OpenMM.h" #include "OpenMM.h"
...@@ -1020,6 +1042,7 @@ class FortranHeaderGenerator(WrapperGenerator): ...@@ -1020,6 +1042,7 @@ class FortranHeaderGenerator(WrapperGenerator):
# Write other methods # Write other methods
nameCount = {} nameCount = {}
allNames = set()
for methodNode in methodList: for methodNode in methodList:
methodName = methodNames[methodNode] methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName): if methodName in (shortClassName, destructorName):
...@@ -1033,8 +1056,8 @@ class FortranHeaderGenerator(WrapperGenerator): ...@@ -1033,8 +1056,8 @@ class FortranHeaderGenerator(WrapperGenerator):
if methodName in nameCount: if methodName in nameCount:
# There are multiple methods with the same name. # There are multiple methods with the same name.
count = nameCount[methodName] count = nameCount[methodName]
methodName = "%s_%d" % (methodName, count)
nameCount[methodName] = count+1 nameCount[methodName] = count+1
methodName = "%s_%d" % (methodName, count)
else: else:
nameCount[methodName] = 1 nameCount[methodName] = 1
returnType = self.getType(getText("type", methodNode)) returnType = self.getType(getText("type", methodNode))
...@@ -1042,6 +1065,10 @@ class FortranHeaderGenerator(WrapperGenerator): ...@@ -1042,6 +1065,10 @@ class FortranHeaderGenerator(WrapperGenerator):
hasReturnArg = not (hasReturnValue or returnType == 'void') hasReturnArg = not (hasReturnValue or returnType == 'void')
functionName = "%s_%s" % (typeName, methodName) functionName = "%s_%s" % (typeName, methodName)
functionName = functionName[:63] functionName = functionName[:63]
if functionName in allNames:
# Two functions get truncated to have the same name, so skip the later ones.
continue
allNames.add(functionName)
if hasReturnValue: if hasReturnValue:
self.out.write(" function ") self.out.write(" function ")
else: else:
...@@ -1585,6 +1612,7 @@ class FortranSourceGenerator(WrapperGenerator): ...@@ -1585,6 +1612,7 @@ class FortranSourceGenerator(WrapperGenerator):
# Write other methods # Write other methods
nameCount = {} nameCount = {}
allNames = set()
for methodNode in methodList: for methodNode in methodList:
methodName = methodNames[methodNode] methodName = methodNames[methodNode]
if methodName in (shortClassName, destructorName): if methodName in (shortClassName, destructorName):
...@@ -1600,12 +1628,16 @@ class FortranSourceGenerator(WrapperGenerator): ...@@ -1600,12 +1628,16 @@ class FortranSourceGenerator(WrapperGenerator):
if methodName in nameCount: if methodName in nameCount:
# There are multiple methods with the same name. # There are multiple methods with the same name.
count = nameCount[methodName] count = nameCount[methodName]
methodName = "%s_%d" % (methodName, count)
nameCount[methodName] = count+1 nameCount[methodName] = count+1
methodName = "%s_%d" % (methodName, count)
else: else:
nameCount[methodName] = 1 nameCount[methodName] = 1
functionName = "%s_%s" % (typeName, methodName) functionName = "%s_%s" % (typeName, methodName)
truncatedName = functionName[:63] truncatedName = functionName[:63]
if truncatedName in allNames:
# Two functions get truncated to have the same name, so skip the later ones.
continue
allNames.add(truncatedName)
self.writeOneMethod(classNode, methodNode, functionName, truncatedName.lower()+'_') self.writeOneMethod(classNode, methodNode, functionName, truncatedName.lower()+'_')
self.writeOneMethod(classNode, methodNode, functionName, truncatedName.upper()) self.writeOneMethod(classNode, methodNode, functionName, truncatedName.upper())
......
...@@ -45,6 +45,7 @@ using namespace OpenMM; ...@@ -45,6 +45,7 @@ using namespace OpenMM;
%} %}
%feature("flatnested", "1");
%feature("autodoc", "0"); %feature("autodoc", "0");
%nodefaultctor; %nodefaultctor;
......
...@@ -112,6 +112,11 @@ def stripOpenmmPrefix(name, rePattern=OPENMM_RE_PATTERN): ...@@ -112,6 +112,11 @@ def stripOpenmmPrefix(name, rePattern=OPENMM_RE_PATTERN):
except: except:
return name return name
def stripClassPrefix(className):
if className.startswith("OpenMM::"):
className = className[8:]
return className
def findNodes(parent, path, **args): def findNodes(parent, path, **args):
nodes = [] nodes = []
for node in parent.findall(path): for node in parent.findall(path):
...@@ -211,7 +216,7 @@ class SwigInputBuilder: ...@@ -211,7 +216,7 @@ class SwigInputBuilder:
# Read all the XML files and merge them into a single document. # Read all the XML files and merge them into a single document.
self.doc = etree.ElementTree(etree.Element('root')) self.doc = etree.ElementTree(etree.Element('root'))
for file in os.listdir(inputDirname): for file in sorted(os.listdir(inputDirname)):
if file.lower().endswith('xml'): if file.lower().endswith('xml'):
root = etree.parse(os.path.join(inputDirname, file)).getroot() root = etree.parse(os.path.join(inputDirname, file)).getroot()
for node in root: for node in root:
...@@ -271,6 +276,7 @@ class SwigInputBuilder: ...@@ -271,6 +276,7 @@ class SwigInputBuilder:
forceSubclassList = [] forceSubclassList = []
integratorSubclassList = [] integratorSubclassList = []
tabulatedFunctionSubclassList = [] tabulatedFunctionSubclassList = []
coordinateTransformationSubclassList = []
for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"): for classNode in findNodes(self.doc.getroot(), "compounddef", kind="class", prot="public"):
className = getText("compoundname", classNode) className = getText("compoundname", classNode)
shortClassName=stripOpenmmPrefix(className) shortClassName=stripOpenmmPrefix(className)
...@@ -287,6 +293,8 @@ class SwigInputBuilder: ...@@ -287,6 +293,8 @@ class SwigInputBuilder:
integratorSubclassList.append(shortClassName) integratorSubclassList.append(shortClassName)
elif baseName == 'OpenMM::TabulatedFunction': elif baseName == 'OpenMM::TabulatedFunction':
tabulatedFunctionSubclassList.append(shortClassName) tabulatedFunctionSubclassList.append(shortClassName)
elif baseName == 'OpenMM::ATMForce::CoordinateTransformation':
coordinateTransformationSubclassList.append(shortClassName)
# We need to include subclasses of DrudeIntegrator, but not DrudeIntegrator itself. # We need to include subclasses of DrudeIntegrator, but not DrudeIntegrator itself.
integratorSubclassList.remove('DrudeIntegrator') integratorSubclassList.remove('DrudeIntegrator')
...@@ -340,6 +348,11 @@ class SwigInputBuilder: ...@@ -340,6 +348,11 @@ class SwigInputBuilder:
self.fOut.write(",\n OpenMM::%s" % name) self.fOut.write(",\n OpenMM::%s" % name)
self.fOut.write(");\n\n") self.fOut.write(");\n\n")
self.fOut.write("%factory(OpenMM::ATMForce::CoordinateTransformation& OpenMM::ATMForce::getParticleTransformation")
for name in sorted(coordinateTransformationSubclassList):
self.fOut.write(",\n OpenMM::ATMForce::%s" % name)
self.fOut.write(");\n\n")
for classNode in self._orderedClassNodes: for classNode in self._orderedClassNodes:
methodList=getClassMethodList(classNode, self.skipMethods) methodList=getClassMethodList(classNode, self.skipMethods)
for items in methodList: for items in methodList:
...@@ -381,7 +394,7 @@ class SwigInputBuilder: ...@@ -381,7 +394,7 @@ class SwigInputBuilder:
if isConstructors: if isConstructors:
hasConstructor=True hasConstructor=True
className = stripOpenmmPrefix(getText("compoundname", classNode)) className = stripClassPrefix(getText("compoundname", classNode))
# If has a constructor then tell swig tell to make a copy method # If has a constructor then tell swig tell to make a copy method
if hasConstructor: if hasConstructor:
self.fOut.write("%%copyctor %s ;\n" % className) self.fOut.write("%%copyctor %s ;\n" % className)
...@@ -391,7 +404,7 @@ class SwigInputBuilder: ...@@ -391,7 +404,7 @@ class SwigInputBuilder:
def writeClassDeclarations(self): def writeClassDeclarations(self):
self.fOut.write("\n/* Class Declarations */\n\n") self.fOut.write("\n/* Class Declarations */\n\n")
for classNode in self._orderedClassNodes: for classNode in self._orderedClassNodes:
className = stripOpenmmPrefix(getText("compoundname", classNode)) className = stripClassPrefix(getText("compoundname", classNode))
if self.fOutDocstring: if self.fOutDocstring:
dNode = classNode.find('detaileddescription') dNode = classNode.find('detaileddescription')
if dNode is not None: if dNode is not None:
...@@ -405,15 +418,21 @@ class SwigInputBuilder: ...@@ -405,15 +418,21 @@ class SwigInputBuilder:
for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"): for baseNodePnt in findNodes(classNode, "basecompoundref", prot="public"):
if "refid" in baseNodePnt.attrib: if "refid" in baseNodePnt.attrib:
baseName = stripOpenmmPrefix(getText(".", baseNodePnt)) baseName = stripClassPrefix(getText(".", baseNodePnt))
self.fOut.write(" : public %s" % baseName) self.fOut.write(" : public %s" % baseName)
self.fOut.write(" {\n") self.fOut.write(" {\n")
self.fOut.write("public:\n") self.fOut.write("public:\n")
self.writeInnerClasses(classNode)
self.writeEnumerations(classNode) self.writeEnumerations(classNode)
self.writeMethods(classNode) self.writeMethods(classNode)
self.fOut.write("};\n\n") self.fOut.write("};\n\n")
self.fOut.write("\n") self.fOut.write("\n")
def writeInnerClasses(self, classNode):
for inner in findNodes(classNode, "innerclass", prot="public"):
name = getNodeText(inner).split("::")[-1];
self.fOut.write(f" class {name};\n")
def writeEnumerations(self, classNode): def writeEnumerations(self, classNode):
enumNodes = [] enumNodes = []
for section in findNodes(classNode, "sectiondef", kind="public-type"): for section in findNodes(classNode, "sectiondef", kind="public-type"):
...@@ -441,6 +460,7 @@ class SwigInputBuilder: ...@@ -441,6 +460,7 @@ class SwigInputBuilder:
def writeMethods(self, classNode): def writeMethods(self, classNode):
methodList=getClassMethodList(classNode, self.skipMethods) methodList=getClassMethodList(classNode, self.skipMethods)
fullClassName = getText("compoundname", classNode)
#write only Constructors #write only Constructors
for items in methodList: for items in methodList:
...@@ -629,7 +649,7 @@ class SwigInputBuilder: ...@@ -629,7 +649,7 @@ class SwigInputBuilder:
if addText: if addText:
self.fOutPythonappend.write("%pythonappend") self.fOutPythonappend.write("%pythonappend")
self.fOutPythonappend.write(" OpenMM::%s::%s(" % key) self.fOutPythonappend.write(" %s::%s(" % (fullClassName, methName))
sepChar='' sepChar=''
outputIndex=0 outputIndex=0
for pNode in paramList: for pNode in paramList:
......
...@@ -131,6 +131,7 @@ STEAL_OWNERSHIP = {("Platform", "registerPlatform") : [0], ...@@ -131,6 +131,7 @@ STEAL_OWNERSHIP = {("Platform", "registerPlatform") : [0],
("System", "addForce") : [0], ("System", "addForce") : [0],
("System", "setVirtualSite") : [1], ("System", "setVirtualSite") : [1],
("ATMForce", "addForce") : [0], ("ATMForce", "addForce") : [0],
("ATMForce", "setParticleTransformation") : [1],
("CustomNonbondedForce", "addTabulatedFunction") : [1], ("CustomNonbondedForce", "addTabulatedFunction") : [1],
("CustomGBForce", "addTabulatedFunction") : [1], ("CustomGBForce", "addTabulatedFunction") : [1],
("CustomHbondForce", "addTabulatedFunction") : [1], ("CustomHbondForce", "addTabulatedFunction") : [1],
...@@ -533,4 +534,7 @@ UNITS = { ...@@ -533,4 +534,7 @@ UNITS = {
("ATMForce", "getDefaultUbcore") : ('unit.kilojoule_per_mole', ()), ("ATMForce", "getDefaultUbcore") : ('unit.kilojoule_per_mole', ()),
("ATMForce", "getDefaultAcore") : (None, ()), ("ATMForce", "getDefaultAcore") : (None, ()),
("ATMForce", "getParticleParameters") : (None, ("unit.nanometer", "unit.nanometer")), ("ATMForce", "getParticleParameters") : (None, ("unit.nanometer", "unit.nanometer")),
("ATMForce", "getParticleTransformation") : (None, ()),
("FixedDisplacement", "getFixedDisplacement1") : ("unit.nanometer", ()),
("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()),
} }
...@@ -427,7 +427,7 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec ...@@ -427,7 +427,7 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec
SWIG_fail; SWIG_fail;
} }
} }
%typemap(typecheck, fragment="Py_SequenceToVec3") Vec3 { %typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE_ARRAY, fragment="Py_SequenceToVec3") Vec3 {
int res = 0; int res = 0;
Py_SequenceToVec3($input, res); Py_SequenceToVec3($input, res);
$1 = SWIG_IsOK(res); $1 = SWIG_IsOK(res);
...@@ -444,7 +444,7 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec ...@@ -444,7 +444,7 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec
} }
$1 = &myVec; $1 = &myVec;
} }
%typemap(typecheck, fragment="Py_SequenceToVec3") const Vec3& { %typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE_ARRAY, fragment="Py_SequenceToVec3") const Vec3& {
int res = 0; int res = 0;
Py_SequenceToVec3($input, res); Py_SequenceToVec3($input, res);
$1 = SWIG_IsOK(res); $1 = SWIG_IsOK(res);
......
...@@ -529,7 +529,21 @@ class TestAPIUnits(unittest.TestCase): ...@@ -529,7 +529,21 @@ class TestAPIUnits(unittest.TestCase):
def testATMForce(self): def testATMForce(self):
"""Tests the ATMForce API features""" """Tests the ATMForce API features"""
force = ATMForce(0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.6, 0.8, -1.0); force = ATMForce(0.1, 0.2, 0.3, 0.4, 0.5, 0.7, 0.6, 0.8, -1.0);
#particle 0: fixed displacements,
force.addParticle(Vec3(1, 2, 3), Vec3(4, 5, 6)) force.addParticle(Vec3(1, 2, 3), Vec3(4, 5, 6))
#particle 1: fixed displacements using a Transformation object
p = force.addParticle()
force.setParticleTransformation(p, FixedDisplacement(Vec3(7, 8, 9), Vec3(10, 11, 12)))
#particle 2: particle distance displacement
p = force.addParticle()
force.setParticleTransformation(p, ParticleOffsetDisplacement(1, 0))
#particle 3: stationary particle
force.addParticle()
self.assertEqual(0.1, force.getGlobalParameterDefaultValue(0)) self.assertEqual(0.1, force.getGlobalParameterDefaultValue(0))
self.assertEqual(0.2, force.getGlobalParameterDefaultValue(1)) self.assertEqual(0.2, force.getGlobalParameterDefaultValue(1))
self.assertEqual(0.3, force.getGlobalParameterDefaultValue(2)) self.assertEqual(0.3, force.getGlobalParameterDefaultValue(2))
...@@ -539,10 +553,32 @@ class TestAPIUnits(unittest.TestCase): ...@@ -539,10 +553,32 @@ class TestAPIUnits(unittest.TestCase):
self.assertEqual(0.6, force.getGlobalParameterDefaultValue(6)) self.assertEqual(0.6, force.getGlobalParameterDefaultValue(6))
self.assertEqual(0.8, force.getGlobalParameterDefaultValue(7)) self.assertEqual(0.8, force.getGlobalParameterDefaultValue(7))
self.assertEqual(-1.0, force.getGlobalParameterDefaultValue(8)) self.assertEqual(-1.0, force.getGlobalParameterDefaultValue(8))
d1, d0 = force.getParticleParameters(0) d1, d0 = force.getParticleParameters(0)
self.assertEqual(Vec3(1, 2, 3)*nanometers, d1) self.assertEqual(Vec3(1, 2, 3)*nanometers, d1)
self.assertEqual(Vec3(4, 5, 6)*nanometers, d0) self.assertEqual(Vec3(4, 5, 6)*nanometers, d0)
fixed_displacement_transformation = force.getParticleTransformation(1)
d1 = fixed_displacement_transformation.getFixedDisplacement1()
d0 = fixed_displacement_transformation.getFixedDisplacement0()
self.assertEqual(Vec3(7, 8, 9)*nanometers, d1)
self.assertEqual(Vec3(10, 11, 12)*nanometers, d0)
vectordistance_displacement_transformation = force.getParticleTransformation(2)
j1 = vectordistance_displacement_transformation.getDestinationParticle1()
i1 = vectordistance_displacement_transformation.getOriginParticle1()
j0 = vectordistance_displacement_transformation.getDestinationParticle0()
i0 = vectordistance_displacement_transformation.getOriginParticle0()
self.assertEqual( 1, j1)
self.assertEqual( 0, i1)
self.assertEqual(-1, j0)
self.assertEqual(-1, i0)
transformation = force.getParticleTransformation(3)
d1, d0 = force.getParticleParameters(3)
self.assertEqual(Vec3(0, 0, 0)*nanometers, d1)
self.assertEqual(Vec3(0, 0, 0)*nanometers, d0)
def testDrudeForce(self): def testDrudeForce(self):
""" Tests the DrudeForce API features """ """ Tests the DrudeForce API features """
force = DrudeForce() force = DrudeForce()
......
...@@ -19,8 +19,9 @@ class TestATMForce(unittest.TestCase): ...@@ -19,8 +19,9 @@ class TestATMForce(unittest.TestCase):
system.addForce(nbforce) system.addForce(nbforce)
atmforce = ATMForce(0.5, 0.5, 0, 0, 0, 0, 0, 0, 1.0) atmforce = ATMForce(0.5, 0.5, 0, 0, 0, 0, 0, 0, 1.0)
atmforce.addParticle(Vec3(0., 0., 0.)) atmforce.addParticle()
atmforce.addParticle(Vec3(1., 0., 0.)) p = atmforce.addParticle()
atmforce.setParticleTransformation(p, FixedDisplacement(Vec3(1., 0., 0.)))
atmforce.addForce(copy.copy(nbforce)) atmforce.addForce(copy.copy(nbforce))
system.removeForce(0) system.removeForce(0)
...@@ -48,3 +49,54 @@ class TestATMForce(unittest.TestCase): ...@@ -48,3 +49,54 @@ class TestATMForce(unittest.TestCase):
epert_expected = 69.4062*kilojoules_per_mole epert_expected = 69.4062*kilojoules_per_mole
assert( abs(epot-epot_expected) < 1.e-3*kilojoules_per_mole ) assert( abs(epot-epot_expected) < 1.e-3*kilojoules_per_mole )
assert( abs(epert-epert_expected) < 1.e-3*kilojoules_per_mole ) assert( abs(epert-epert_expected) < 1.e-3*kilojoules_per_mole )
assert isinstance(atmforce.getParticleTransformation(0), FixedDisplacement)
def test3ParticlesNonbondedSwap(self):
"""Test coordinate swap"""
system = System()
system.addParticle(1.0)
system.addParticle(1.0)
system.addParticle(1.0)
nbforce = NonbondedForce();
nbforce.addParticle( 0.0, 1.0, 1.0)
nbforce.addParticle( 1.0, 1.0, 1.0)
nbforce.addParticle(-1.0, 1.0, 1.0)
atmforce = ATMForce(0.5, 0.5, 0, 0, 0, 0, 0, 0, 1.0)
atmforce.addParticle() #particle 0 is not displaced
#particle 1's coordinate is swapped with 2
p = atmforce.addParticle()
atmforce.setParticleTransformation(p, ParticleOffsetDisplacement(2, 1))
#particle 2's coordinate is swapped with 1
p = atmforce.addParticle()
atmforce.setParticleTransformation(p, ParticleOffsetDisplacement(1, 2))
atmforce.addForce(nbforce)
system.addForce(atmforce)
integrator = VerletIntegrator(1.0)
platform = Platform.getPlatform('Reference')
context = Context(system, integrator, platform)
positions = []
positions.append(Vec3( 0., 0., 0.))
positions.append(Vec3( 1., 0., 0.))
positions.append(Vec3(-1., 0., 0.))
context.setPositions(positions)
state = context.getState(getEnergy = True, getForces = True)
epot = state.getPotentialEnergy()
(u1, u0, energy) = atmforce.getPerturbationEnergy(context)
epert = u1 - u0
#print("Potential energy = ", epot)
#print("ATM perturbation energy = ", epert)
epot_expected = -69.52925*kilojoules_per_mole
epert_expected = 0*kilojoules_per_mole
assert( abs(epot-epot_expected) < 1.e-3*kilojoules_per_mole )
assert( abs(epert-epert_expected) < 1.e-3*kilojoules_per_mole )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment