Unverified Commit 2fbed592 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Created PythonForce (#5122)

* Initial implementation of PythonForce

* Continuing implementation of PythonForce

* Tests for PythonForce

* Fix memory leaks

* Documentation for PythonForce

* Fixed incorrect return type

* Fix compilation error on Python older than 3.12

* Handle all dtypes

* Optimizations to PythonForce

* Optimized getPositions()

* Test all platforms

* Fix test failures
parent 74912095
...@@ -623,7 +623,8 @@ EXCLUDE_PATTERNS = @CMAKE_SOURCE_DIR@/*/tests/* \ ...@@ -623,7 +623,8 @@ EXCLUDE_PATTERNS = @CMAKE_SOURCE_DIR@/*/tests/* \
*RpmdKernels.h \ *RpmdKernels.h \
*RPMDUpdater.h \ *RPMDUpdater.h \
*OpenMMFortranModule.f90 \ *OpenMMFortranModule.f90 \
*OpenMMCWrapper.h *OpenMMCWrapper.h \
*PythonForce.h
# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names # The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
# (namespaces, classes, functions, etc.) that should be excluded from the # (namespaces, classes, functions, etc.) that should be excluded from the
......
...@@ -30,7 +30,7 @@ copyright = u"2015-2025, Stanford University and the Authors" ...@@ -30,7 +30,7 @@ copyright = u"2015-2025, Stanford University and the Authors"
version = openmm.version.short_version version = openmm.version.short_version
release = openmm.version.full_version release = openmm.version.full_version
exclude_patterns = ["_build", "_templates"] exclude_patterns = ["_build", "_templates", "**/*.ComputationWrapper.*", "**/*.PythonForceProxy.*"]
html_static_path = ["_static"] html_static_path = ["_static"]
templates_path = ["_templates"] templates_path = ["_templates"]
......
...@@ -59,7 +59,9 @@ def library_template_variables(): ...@@ -59,7 +59,9 @@ def library_template_variables():
# these classes are useless and not worth documenting. # these classes are useless and not worth documenting.
exclude.extend([ exclude.extend([
'openmm.openmm.SwigPyIterator', 'openmm.openmm.SwigPyIterator',
'openmm.openmm.OpenMMException']) 'openmm.openmm.OpenMMException',
'openmm.openmm.ComputationWrapper',
'openmm.openmm.PythonForceProxy'])
for _, klass in mm_klasses: for _, klass in mm_klasses:
full = fullname(klass) full = fullname(klass)
......
...@@ -57,6 +57,7 @@ ...@@ -57,6 +57,7 @@
#include "openmm/MonteCarloBarostat.h" #include "openmm/MonteCarloBarostat.h"
#include "openmm/OrientationRestraintForce.h" #include "openmm/OrientationRestraintForce.h"
#include "openmm/PeriodicTorsionForce.h" #include "openmm/PeriodicTorsionForce.h"
#include "openmm/PythonForce.h"
#include "openmm/QTBIntegrator.h" #include "openmm/QTBIntegrator.h"
#include "openmm/RBTorsionForce.h" #include "openmm/RBTorsionForce.h"
#include "openmm/RGForce.h" #include "openmm/RGForce.h"
...@@ -167,9 +168,12 @@ public: ...@@ -167,9 +168,12 @@ public:
/** /**
* Get the positions of all particles. * Get the positions of all particles.
* *
* @param context the context in which to execute this kernel
* @param positions on exit, this contains the particle positions * @param positions on exit, this contains the particle positions
* @param allowPeriodic if true, the returned positions might be translated into a
* different periodic box to keep them closer to the origin
*/ */
virtual void getPositions(ContextImpl& context, std::vector<Vec3>& positions) = 0; virtual void getPositions(ContextImpl& context, std::vector<Vec3>& positions, bool allowPeriodic=false) = 0;
/** /**
* Set the positions of all particles. * Set the positions of all particles.
* *
...@@ -1941,6 +1945,34 @@ public: ...@@ -1941,6 +1945,34 @@ public:
virtual double execute(ContextImpl& context, bool includeForces, bool includeEnergy) = 0; virtual double execute(ContextImpl& context, bool includeForces, bool includeEnergy) = 0;
}; };
/**
* This kernel is invoked by PythonForce to calculate the forces acting on the system and the energy of the system.
*/
class CalcPythonForceKernel : public KernelImpl {
public:
static std::string Name() {
return "CalcPythonForce";
}
CalcPythonForceKernel(std::string name, const Platform& platform) : KernelImpl(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
* @param force the PythonForce this kernel will be used for
*/
virtual void initialize(const System& system, const PythonForce& force) = 0;
/**
* Execute the kernel to calculate the forces and/or energy.
*
* @param context the context in which to execute this kernel
* @param includeForces true if forces should be calculated
* @param includeEnergy true if the energy should be calculated
* @return the potential energy due to the force
*/
virtual double execute(ContextImpl& context, bool includeForces, bool includeEnergy) = 0;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_KERNELS_H_*/ #endif /*OPENMM_KERNELS_H_*/
...@@ -68,6 +68,7 @@ ...@@ -68,6 +68,7 @@
#include "openmm/OpenMMException.h" #include "openmm/OpenMMException.h"
#include "openmm/OrientationRestraintForce.h" #include "openmm/OrientationRestraintForce.h"
#include "openmm/PeriodicTorsionForce.h" #include "openmm/PeriodicTorsionForce.h"
#include "openmm/PythonForce.h"
#include "openmm/QTBIntegrator.h" #include "openmm/QTBIntegrator.h"
#include "openmm/RBTorsionForce.h" #include "openmm/RBTorsionForce.h"
#include "openmm/RGForce.h" #include "openmm/RGForce.h"
......
...@@ -100,7 +100,7 @@ namespace OpenMM { ...@@ -100,7 +100,7 @@ namespace OpenMM {
* *
* \endverbatim * \endverbatim
* *
* where the "base" values are the ones specified by addParticle() and "oaram" is the current value * where the "base" values are the ones specified by addParticle() and "param" is the current value
* of the Context parameter. A single Context parameter can apply offsets to multiple particles, * of the Context parameter. A single Context parameter can apply offsets to multiple particles,
* and multiple parameters can be used to apply offsets to the same particle. Parameters can also be used * and multiple parameters can be used to apply offsets to the same particle. Parameters can also be used
* to modify exceptions in exactly the same way by calling addExceptionParameterOffset(). * to modify exceptions in exactly the same way by calling addExceptionParameterOffset().
......
#ifndef OPENMM_PYTHONFORCE_H_
#define OPENMM_PYTHONFORCE_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "Force.h"
#include "State.h"
#include <map>
#include <string>
#include "internal/windowsExport.h"
namespace OpenMM {
/**
* This abstract class represents an interface for performing a computation. It is not intended to
* be used or subclassed directly by users. The Python wrapper contains a subclass that implements
* the interface using a Python function.
* @private
*/
class OPENMM_EXPORT PythonForceComputation {
public:
PythonForceComputation() {
}
virtual ~PythonForceComputation() {
}
/**
* Compute forces and energy. The State contains particle positions, parameters, and
* optionally periodic box vectors. Implementations should store the potential energy
* and particle forces into the energy and forces arguments. The forces argument points
* to an array of length 3*particles. Its type is either float or double, depending on
* the value of forcesAreDouble.
*/
virtual void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const = 0;
};
/**
* This class provides a mechanism for computing forces and energy with Python code. To use it,
* define a Python function that takes a State object as its only argument. The State contains
* particle positions and global parameters. Based on it, the function should compute the
* potential energy and forces, returning them as its two return values. The forces should be
* represented as a NumPy array of shape (# particles, 3). For example,
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: python
*
* def compute(state):
* pos = state.getPositions(asNumpy=True).value_in_unit(nanometer)
* k = state.getParameters()['k']
* energy = k*np.sum(pos*pos)
* force = -0.5*k*pos
* return energy*kilojoules_per_mole, force*kilojoules_per_mole/nanometer
*
* \endverbatim
*
* Attaching units to the return values is optional. If units are omitted, the values are assumed
* to be in the default units (energy in kJ/mol, forces in kJ/mol/nm).
*
* Now create a Python force, passing the function to the constructor. If you want the force
* to depend on global parameters, pass a dict as the second parameter with the names and default
* values of the parameters.
*
* \verbatim embed:rst:leading-asterisk
* .. code-block:: python
*
* force = PythonForce(compute, {'k':2.5})
*
* \endverbatim
*
* The default value of a parameter is its value in newly created Contexts. After a Context is
* created, you can change the values of parameters by calling setParameter() on it.
*
* The PythonForce cannot tell whether the function you provide makes use of periodic boundary
* conditions, so you must tell it. To make the force periodic, call
* setUsesPeriodicBoundaryConditions(True). This will cause usesPeriodicBoundaryConditions()
* to return True, and the State passed to the computation function will contain periodic
* box vectors. The positions may also be wrapped into a different periodic box to keep them
* closer to the origin and improve accuracy.
*
* When using XmlSerializer to save a PythonForce, it uses the Python pickle module to save
* the computation function. If it cannot be pickled, you will not be able to serialize the
* PythonForce. Functions defined at the top level of a module can usually be pickled, but local
* functions defined inside another function cannot.
*
* Compared to other types of forces, computing a force with Python code is slow and has high
* overhead. When possible, using a different force class is usually preferred. For example,
* the Python force shown in the example code above (a harmonic force attracting every particle
* to the origin) could be implemented just as easily with a CustomExternalForce, and would
* execute much faster if done that way.
*/
class OPENMM_EXPORT PythonForce : public Force {
public:
/**
* Create a PythonForce. This constructor is used internally, and is not intended for use
* by users. The Python wrapper defines an alternate constructor that takes a Python
* function instead of a PythonForceComputation.
*
* @param computation an object defining how the forces and energy should be computed
* @param globalParameters any global parameters used by the force. Keys are the parameter
* names, and the corresponding values are their default values.
* @private
*/
explicit PythonForce(PythonForceComputation* computation, const std::map<std::string, double>& globalParameters);
~PythonForce();
/**
* Get the PythonForceComputation that defines the computation.
* @private
*/
const PythonForceComputation& getComputation() const;
/**
* Get all global parameters defined by this force. Keys are the parameter names, and the
* corresponding values are their default values.
*/
const std::map<std::string, double>& getGlobalParameters() const;
/**
* Get the pickled representation of the computation function. If it cannot be pickled,
* this will be an empty vector.
*/
const std::vector<char>& getPickledFunction() const;
/**
* Set the pickled representation of the computation function. This is called automatically
* by the Python constructor.
* @private
*/
void setPickledFunction(char* function, int length);
/**
* Returns whether or not this force makes use of periodic boundary
* conditions.
*
* @returns true if force uses PBC and false otherwise
*/
bool usesPeriodicBoundaryConditions() const;
/**
* Set whether or not this force makes use of periodic boundary conditions.
* If this is set to true, periodic box vectors can be retrieved from the
* State passed to the computation function.
*/
void setUsesPeriodicBoundaryConditions(bool periodic);
protected:
ForceImpl* createImpl() const;
private:
PythonForceComputation* computation;
std::map<std::string, double> globalParameters;
bool usePeriodic;
std::vector<char> pickled;
};
} // namespace OpenMM
#endif /*OPENMM_PYTHONFORCE_H_*/
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2020 Stanford University and the Authors. * * Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -142,6 +142,7 @@ private: ...@@ -142,6 +142,7 @@ private:
int types; int types;
double time, ke, pe; double time, ke, pe;
long long stepCount; long long stepCount;
bool hasBoxVectors;
std::vector<Vec3> positions; std::vector<Vec3> positions;
std::vector<Vec3> velocities; std::vector<Vec3> velocities;
std::vector<Vec3> forces; std::vector<Vec3> forces;
......
...@@ -102,8 +102,10 @@ public: ...@@ -102,8 +102,10 @@ public:
* Get the positions of all particles. * Get the positions of all particles.
* *
* @param positions on exit, this contains the particle positions * @param positions on exit, this contains the particle positions
* @param allowPeriodic if true, the returned positions might be translated into a
* different periodic box to keep them closer to the origin
*/ */
void getPositions(std::vector<Vec3>& positions); void getPositions(std::vector<Vec3>& positions, bool allowPeriodic=false);
/** /**
* Set the positions of all particles. * Set the positions of all particles.
* *
......
#ifndef OPENMM_PYTHONFORCEIMPL_H_
#define OPENMM_PYTHONFORCEIMPL_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "CustomCPPForceImpl.h"
#include "openmm/PythonForce.h"
#include "openmm/Kernel.h"
#include <utility>
#include <map>
#include <string>
namespace OpenMM {
/**
* This is the internal implementation of PythonForce.
*/
class PythonForceImpl : public ForceImpl {
public:
PythonForceImpl(const PythonForce& owner);
~PythonForceImpl();
void initialize(ContextImpl& context);
const PythonForce& getOwner() const {
return owner;
}
void updateContextState(ContextImpl& context, bool& forcesInvalid) {
// This force field doesn't update the state directly.
}
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups);
std::map<std::string, double> getDefaultParameters();
std::vector<std::string> getKernelNames();
std::vector<std::pair<int, int> > getBondedParticles() const {
return {};
}
private:
const PythonForce& owner;
const PythonForceComputation& computation;
std::map<std::string, double> defaultParameters;
bool usePeriodic;
Kernel kernel;
};
} // namespace OpenMM
#endif /*OPENMM_PYTHONFORCEIMPL_H_*/
...@@ -116,7 +116,7 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const { ...@@ -116,7 +116,7 @@ State Context::getState(int types, bool enforcePeriodicBox, int groups) const {
} }
if (types&State::Positions) { if (types&State::Positions) {
vector<Vec3> positions; vector<Vec3> positions;
impl->getPositions(positions); impl->getPositions(positions, enforcePeriodicBox);
if (enforcePeriodicBox) { if (enforcePeriodicBox) {
const vector<vector<int> >& molecules = impl->getMolecules(); const vector<vector<int> >& molecules = impl->getMolecules();
for (auto& mol : molecules) { for (auto& mol : molecules) {
......
...@@ -223,8 +223,8 @@ void ContextImpl::setStepCount(long long count) { ...@@ -223,8 +223,8 @@ void ContextImpl::setStepCount(long long count) {
updateStateDataKernel.getAs<UpdateStateDataKernel>().setStepCount(*this, count); updateStateDataKernel.getAs<UpdateStateDataKernel>().setStepCount(*this, count);
} }
void ContextImpl::getPositions(std::vector<Vec3>& positions) { void ContextImpl::getPositions(std::vector<Vec3>& positions, bool allowPeriodic) {
updateStateDataKernel.getAs<UpdateStateDataKernel>().getPositions(*this, positions); updateStateDataKernel.getAs<UpdateStateDataKernel>().getPositions(*this, positions, allowPeriodic);
} }
void ContextImpl::setPositions(const std::vector<Vec3>& positions) { void ContextImpl::setPositions(const std::vector<Vec3>& positions) {
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/PythonForce.h"
#include "openmm/internal/PythonForceImpl.h"
using namespace OpenMM;
using namespace std;
PythonForce::PythonForce(PythonForceComputation* computation, const map<string, double>& globalParameters) :
computation(computation), globalParameters(globalParameters), usePeriodic(false) {
}
PythonForce::~PythonForce() {
delete computation;
}
const PythonForceComputation& PythonForce::getComputation() const {
return *computation;
}
const map<string, double>& PythonForce::getGlobalParameters() const {
return globalParameters;
}
bool PythonForce::usesPeriodicBoundaryConditions() const {
return usePeriodic;
}
void PythonForce::setUsesPeriodicBoundaryConditions(bool periodic) {
usePeriodic = periodic;
}
const vector<char>& PythonForce::getPickledFunction() const {
return pickled;
}
void PythonForce::setPickledFunction(char* function, int length) {
pickled.clear();
for (int i = 0; i < length; i++)
pickled.push_back(function[i]);
}
ForceImpl* PythonForce::createImpl() const {
return new PythonForceImpl(*this);
}
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/OpenMMException.h"
#include "openmm/internal/ContextImpl.h"
#include "openmm/internal/PythonForceImpl.h"
#include "openmm/kernels.h"
#include <set>
#include <sstream>
using namespace OpenMM;
using namespace std;
PythonForceImpl::PythonForceImpl(const PythonForce& owner) : owner(owner), computation(owner.getComputation()),
defaultParameters(owner.getGlobalParameters()), usePeriodic(owner.usesPeriodicBoundaryConditions()) {
forceGroup = owner.getForceGroup();
}
PythonForceImpl::~PythonForceImpl() {
}
void PythonForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcPythonForceKernel::Name(), context);
kernel.getAs<CalcPythonForceKernel>().initialize(context.getSystem(), owner);
}
double PythonForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<forceGroup)) != 0)
return kernel.getAs<CalcPythonForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0;
}
vector<string> PythonForceImpl::getKernelNames() {
return {CalcCustomCPPForceKernel::Name()};
}
map<string, double> PythonForceImpl::getDefaultParameters() {
return defaultParameters;
}
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2016 Stanford University and the Authors. * * Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -65,11 +65,15 @@ double State::getPotentialEnergy() const { ...@@ -65,11 +65,15 @@ double State::getPotentialEnergy() const {
return pe; return pe;
} }
void State::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const { void State::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) const {
if (!hasBoxVectors)
throw OpenMMException("Invoked getPeriodicBoxVectors() on a State which does not contain box vectors.");
a = periodicBoxVectors[0]; a = periodicBoxVectors[0];
b = periodicBoxVectors[1]; b = periodicBoxVectors[1];
c = periodicBoxVectors[2]; c = periodicBoxVectors[2];
} }
double State::getPeriodicBoxVolume() const { double State::getPeriodicBoxVolume() const {
if (!hasBoxVectors)
throw OpenMMException("Invoked getPeriodicBoxVolume() on a State which does not contain box vectors.");
return periodicBoxVectors[0].dot(periodicBoxVectors[1].cross(periodicBoxVectors[2])); return periodicBoxVectors[0].dot(periodicBoxVectors[1].cross(periodicBoxVectors[2]));
} }
const map<string, double>& State::getParameters() const { const map<string, double>& State::getParameters() const {
...@@ -84,7 +88,7 @@ const map<string, double>& State::getEnergyParameterDerivatives() const { ...@@ -84,7 +88,7 @@ const map<string, double>& State::getEnergyParameterDerivatives() const {
} }
const SerializationNode& State::getIntegratorParameters() const { const SerializationNode& State::getIntegratorParameters() const {
if ((types&IntegratorParameters) == 0) if ((types&IntegratorParameters) == 0)
throw OpenMMException("Invoked getPIntegratorarameters() on a State which does not contain integrator parameters."); throw OpenMMException("Invoked getIntegratorParameters() on a State which does not contain integrator parameters.");
return integratorParameters; return integratorParameters;
} }
SerializationNode& State::updateIntegratorParameters() { SerializationNode& State::updateIntegratorParameters() {
...@@ -95,9 +99,9 @@ SerializationNode& State::updateIntegratorParameters() { ...@@ -95,9 +99,9 @@ SerializationNode& State::updateIntegratorParameters() {
int State::getDataTypes() const { int State::getDataTypes() const {
return types; return types;
} }
State::State(double time, long long stepCount) : types(0), time(time), stepCount(stepCount), ke(0), pe(0) { State::State(double time, long long stepCount) : types(0), time(time), stepCount(stepCount), ke(0), pe(0), hasBoxVectors(false) {
} }
State::State() : types(0), time(0.0), ke(0), pe(0) { State::State() : types(0), time(0.0), ke(0), pe(0), hasBoxVectors(false) {
} }
void State::setPositions(const std::vector<Vec3>& pos) { void State::setPositions(const std::vector<Vec3>& pos) {
positions = pos; positions = pos;
...@@ -131,6 +135,7 @@ void State::setEnergy(double kinetic, double potential) { ...@@ -131,6 +135,7 @@ void State::setEnergy(double kinetic, double potential) {
} }
void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) { void State::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
hasBoxVectors = true;
periodicBoxVectors[0] = a; periodicBoxVectors[0] = a;
periodicBoxVectors[1] = b; periodicBoxVectors[1] = b;
periodicBoxVectors[2] = c; periodicBoxVectors[2] = c;
......
...@@ -80,9 +80,12 @@ public: ...@@ -80,9 +80,12 @@ public:
/** /**
* Get the positions of all particles. * Get the positions of all particles.
* *
* @param context the context in which to execute this kernel
* @param positions on exit, this contains the particle positions * @param positions on exit, this contains the particle positions
* @param allowPeriodic if true, the returned positions might be translated into a
* different periodic box to keep them closer to the origin
*/ */
void getPositions(ContextImpl& context, std::vector<Vec3>& positions); void getPositions(ContextImpl& context, std::vector<Vec3>& positions, bool allowPeriodic=false);
/** /**
* Set the positions of all particles. * Set the positions of all particles.
* *
...@@ -1459,6 +1462,58 @@ private: ...@@ -1459,6 +1462,58 @@ private:
double energy; double energy;
}; };
/**
* This kernel is invoked by PythonForce to calculate the forces acting on the system and the energy of the system.
*/
class CommonCalcPythonForceKernel : public CalcPythonForceKernel {
public:
CommonCalcPythonForceKernel(std::string name, const Platform& platform, OpenMM::ContextImpl& contextImpl, ComputeContext& cc) :
CalcPythonForceKernel(name, platform), contextImpl(contextImpl), cc(cc) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
* @param force the PythonForce this kernel will be used for
*/
void initialize(const System& system, const PythonForce& force);
/**
* Execute the kernel to calculate the forces and/or energy.
*
* @param context the context in which to execute this kernel
* @param includeForces true if forces should be calculated
* @param includeEnergy true if the energy should be calculated
* @return the potential energy due to the force
*/
double execute(ContextImpl& context, bool includeForces, bool includeEnergy);
/**
* The is called by the pre-computation to start the calculation running.
*/
void beginComputation(bool includeForces, bool includeEnergy, int groups);
/**
* This is called by the worker thread to do the computation.
*/
void executeOnWorkerThread(bool includeForces);
/**
* This is called by the post-computation to add the forces to the main array.
*/
double addForces(bool includeForces, bool includeEnergy, int groups);
private:
class ExecuteTask;
class StartCalculationPreComputation;
class AddForcesPostComputation;
OpenMM::ContextImpl& contextImpl;
ComputeContext& cc;
const PythonForceComputation* computation;
ComputeArray forcesArray;
ComputeKernel addForcesKernel;
std::vector<Vec3> positionsVec;
std::vector<double> forcesVec;
int forceGroupFlag;
double energy;
bool usePeriodic;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_COMMONKERNELS_H_*/ #endif /*OPENMM_COMMONKERNELS_H_*/
...@@ -91,7 +91,7 @@ void CommonUpdateStateDataKernel::setStepCount(const ContextImpl& context, long ...@@ -91,7 +91,7 @@ void CommonUpdateStateDataKernel::setStepCount(const ContextImpl& context, long
ctx->setStepCount(count); ctx->setStepCount(count);
} }
void CommonUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions) { void CommonUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3>& positions, bool allowPeriodic) {
ContextSelector selector(cc); ContextSelector selector(cc);
int numParticles = context.getSystem().getNumParticles(); int numParticles = context.getSystem().getNumParticles();
positions.resize(numParticles); positions.resize(numParticles);
...@@ -123,29 +123,55 @@ void CommonUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3 ...@@ -123,29 +123,55 @@ void CommonUpdateStateDataKernel::getPositions(ContextImpl& context, vector<Vec3
int numThreads = threads.getNumThreads(); int numThreads = threads.getNumThreads();
int start = threadIndex*numParticles/numThreads; int start = threadIndex*numParticles/numThreads;
int end = (threadIndex+1)*numParticles/numThreads; int end = (threadIndex+1)*numParticles/numThreads;
if (cc.getUseDoublePrecision()) { if (allowPeriodic) {
mm_double4* posq = (mm_double4*) cc.getPinnedBuffer(); if (cc.getUseDoublePrecision()) {
for (int i = start; i < end; ++i) { mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
mm_double4 pos = posq[i]; for (int i = start; i < end; ++i) {
mm_int4 offset = cc.getPosCellOffsets()[i]; mm_double4 pos = posq[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z; positions[order[i]] = Vec3(pos.x, pos.y, pos.z);
}
} }
} else if (cc.getUseMixedPrecision()) {
else if (cc.getUseMixedPrecision()) { mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer(); for (int i = start; i < end; ++i) {
for (int i = start; i < end; ++i) { mm_float4 pos1 = posq[i];
mm_float4 pos1 = posq[i]; mm_float4 pos2 = posCorrection[i];
mm_float4 pos2 = posCorrection[i]; positions[order[i]] = Vec3((double)pos1.x+(double)pos2.x, (double)pos1.y+(double)pos2.y, (double)pos1.z+(double)pos2.z);
mm_int4 offset = cc.getPosCellOffsets()[i]; }
positions[order[i]] = Vec3((double)pos1.x+(double)pos2.x, (double)pos1.y+(double)pos2.y, (double)pos1.z+(double)pos2.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z; }
else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos = posq[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z);
}
} }
} }
else { else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer(); if (cc.getUseDoublePrecision()) {
for (int i = start; i < end; ++i) { mm_double4* posq = (mm_double4*) cc.getPinnedBuffer();
mm_float4 pos = posq[i]; for (int i = start; i < end; ++i) {
mm_int4 offset = cc.getPosCellOffsets()[i]; mm_double4 pos = posq[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z; mm_int4 offset = cc.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else if (cc.getUseMixedPrecision()) {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos1 = posq[i];
mm_float4 pos2 = posCorrection[i];
mm_int4 offset = cc.getPosCellOffsets()[i];
positions[order[i]] = Vec3((double)pos1.x+(double)pos2.x, (double)pos1.y+(double)pos2.y, (double)pos1.z+(double)pos2.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
else {
mm_float4* posq = (mm_float4*) cc.getPinnedBuffer();
for (int i = start; i < end; ++i) {
mm_float4 pos = posq[i];
mm_int4 offset = cc.getPosCellOffsets()[i];
positions[order[i]] = Vec3(pos.x, pos.y, pos.z)-boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
} }
} }
}); });
...@@ -4468,3 +4494,125 @@ double CommonCalcCustomCPPForceKernel::addForces(bool includeForces, bool includ ...@@ -4468,3 +4494,125 @@ double CommonCalcCustomCPPForceKernel::addForces(bool includeForces, bool includ
return energy; return energy;
} }
class CommonCalcPythonForceKernel::StartCalculationPreComputation : public ComputeContext::ForcePreComputation {
public:
StartCalculationPreComputation(CommonCalcPythonForceKernel& owner) : owner(owner) {
}
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
owner.beginComputation(includeForces, includeEnergy, groups);
}
CommonCalcPythonForceKernel& owner;
};
class CommonCalcPythonForceKernel::ExecuteTask : public ComputeContext::WorkTask {
public:
ExecuteTask(CommonCalcPythonForceKernel& owner, bool includeForces) : owner(owner), includeForces(includeForces) {
}
void execute() {
owner.executeOnWorkerThread(includeForces);
}
CommonCalcPythonForceKernel& owner;
bool includeForces;
};
class CommonCalcPythonForceKernel::AddForcesPostComputation : public ComputeContext::ForcePostComputation {
public:
AddForcesPostComputation(CommonCalcPythonForceKernel& owner) : owner(owner) {
}
double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
return owner.addForces(includeForces, includeEnergy, groups);
}
CommonCalcPythonForceKernel& owner;
};
void CommonCalcPythonForceKernel::initialize(const System& system, const PythonForce& force) {
ContextSelector selector(cc);
computation = &force.getComputation();
usePeriodic = force.usesPeriodicBoundaryConditions();
int numParticles = system.getNumParticles();
positionsVec.resize(numParticles);
forcesVec.resize(3*numParticles);
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
forcesArray.initialize(cc, 3*numParticles, elementSize, "forces");
map<string, string> defines;
defines["NUM_ATOMS"] = cc.intToString(numParticles);
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
ComputeProgram program = cc.compileProgram(CommonKernelSources::customCppForce, defines);
addForcesKernel = program->createKernel("addForces");
addForcesKernel->addArg(forcesArray);
addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray());
forceGroupFlag = (1<<force.getForceGroup());
if (cc.getNumContexts() == 1) {
cc.addPreComputation(new StartCalculationPreComputation(*this));
cc.addPostComputation(new AddForcesPostComputation(*this));
}
}
double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
if (cc.getNumContexts() == 1) {
// This method does nothing. The actual calculation is started by the pre-computation, continued on
// the worker thread, and finished by the post-computation.
return 0;
}
// When using multiple GPUs, this method is itself called from the worker thread.
// Submitting additional tasks and waiting for them to complete would lead to
// a deadlock.
if (cc.getContextIndex() != 0)
return 0.0;
contextImpl.getPositions(positionsVec);
executeOnWorkerThread(includeForces);
return addForces(includeForces, includeEnergy, -1);
}
void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool includeEnergy, int groups) {
if ((groups&forceGroupFlag) == 0)
return;
// The actual force computation will be done on a different thread.
cc.getWorkThread().addTask(new ExecuteTask(*this, includeForces));
}
void CommonCalcPythonForceKernel::executeOnWorkerThread(bool includeForces) {
contextImpl.getPositions(positionsVec, usePeriodic || !cc.getNonbondedUtilities().getUsePeriodic());
State::StateBuilder builder(contextImpl.getTime(), contextImpl.getStepCount());
builder.setPositions(positionsVec);
builder.setParameters(contextImpl.getParameters());
if (usePeriodic) {
Vec3 a, b, c;
contextImpl.getPeriodicBoxVectors(a, b, c);
builder.setPeriodicBoxVectors(a, b, c);
}
State state = builder.getState();
computation->compute(state, energy, forcesVec.data(), cc.getUseDoublePrecision());
if (includeForces) {
ContextSelector selector(cc);
forcesArray.upload(forcesVec.data());
}
}
double CommonCalcPythonForceKernel::addForces(bool includeForces, bool includeEnergy, int groups) {
if ((groups&forceGroupFlag) == 0)
return 0;
// Wait until executeOnWorkerThread() is finished.
if (cc.getNumContexts() == 1)
cc.getWorkThread().flush();
// Add in the forces.
if (includeForces) {
ContextSelector selector(cc);
addForcesKernel->execute(cc.getNumAtoms());
}
// Return the energy.
return energy;
}
...@@ -125,6 +125,8 @@ KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform ...@@ -125,6 +125,8 @@ KernelImpl* CudaKernelFactory::createKernelImpl(std::string name, const Platform
return new CommonCalcCustomCPPForceKernel(name, platform, context, cu); return new CommonCalcCustomCPPForceKernel(name, platform, context, cu);
if (name == CalcOrientationRestraintForceKernel::Name()) if (name == CalcOrientationRestraintForceKernel::Name())
return new CommonCalcOrientationRestraintForceKernel(name, platform, cu); return new CommonCalcOrientationRestraintForceKernel(name, platform, cu);
if (name == CalcPythonForceKernel::Name())
return new CommonCalcPythonForceKernel(name, platform, context, cu);
if (name == CalcRGForceKernel::Name()) if (name == CalcRGForceKernel::Name())
return new CommonCalcRGForceKernel(name, platform, cu); return new CommonCalcRGForceKernel(name, platform, cu);
if (name == CalcRMSDForceKernel::Name()) if (name == CalcRMSDForceKernel::Name())
......
...@@ -94,6 +94,7 @@ CudaPlatform::CudaPlatform() { ...@@ -94,6 +94,7 @@ CudaPlatform::CudaPlatform() {
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory); registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory); registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory); registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory); registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory); registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory); registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "CudaTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment