Unverified Commit 2fbed592 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Created PythonForce (#5122)

* Initial implementation of PythonForce

* Continuing implementation of PythonForce

* Tests for PythonForce

* Fix memory leaks

* Documentation for PythonForce

* Fixed incorrect return type

* Fix compilation error on Python older than 3.12

* Handle all dtypes

* Optimizations to PythonForce

* Optimized getPositions()

* Test all platforms

* Fix test failures
parent 74912095
......@@ -124,6 +124,8 @@ KernelImpl* HipKernelFactory::createKernelImpl(std::string name, const Platform&
return new CommonCalcCustomCPPForceKernel(name, platform, context, cu);
if (name == CalcOrientationRestraintForceKernel::Name())
return new CommonCalcOrientationRestraintForceKernel(name, platform, cu);
if (name == CalcPythonForceKernel::Name())
return new CommonCalcPythonForceKernel(name, platform, context, cu);
if (name == CalcRGForceKernel::Name())
return new CommonCalcRGForceKernel(name, platform, cu);
if (name == CalcRMSDForceKernel::Name())
......
......@@ -94,6 +94,7 @@ HipPlatform::HipPlatform() {
registerKernelFactory(CalcCustomCPPForceKernel::Name(), factory);
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "HipTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
......@@ -123,6 +123,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo
return new CommonCalcCustomCPPForceKernel(name, platform, context, cl);
if (name == CalcOrientationRestraintForceKernel::Name())
return new CommonCalcOrientationRestraintForceKernel(name, platform, cl);
if (name == CalcPythonForceKernel::Name())
return new CommonCalcPythonForceKernel(name, platform, context, cl);
if (name == CalcRGForceKernel::Name())
return new CommonCalcRGForceKernel(name, platform, cl);
if (name == CalcRMSDForceKernel::Name())
......
......@@ -85,6 +85,7 @@ OpenCLPlatform::OpenCLPlatform() {
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "OpenCLTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
......@@ -155,9 +155,12 @@ public:
/**
* Get the positions of all particles.
*
* @param context the context in which to execute this kernel
* @param positions on exit, this contains the particle positions
* @param allowPeriodic if true, the returned positions might be translated into a
* different periodic box to keep them closer to the origin
*/
void getPositions(ContextImpl& context, std::vector<Vec3>& positions);
void getPositions(ContextImpl& context, std::vector<Vec3>& positions, bool allowPeriodic=false);
/**
* Set the positions of all particles.
*
......@@ -1944,6 +1947,35 @@ private:
std::vector<Vec3> forces;
};
/**
* This kernel is invoked by PythonForceImpl to calculate the forces acting on the system and the energy of the system.
*/
class ReferenceCalcPythonForceKernel : public CalcPythonForceKernel {
public:
ReferenceCalcPythonForceKernel(std::string name, const Platform& platform) : CalcPythonForceKernel(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
* @param force the PythonForce this kernel will be used for
*/
void initialize(const System& system, const PythonForce& force);
/**
* Execute the kernel to calculate the forces and/or energy.
*
* @param context the context in which to execute this kernel
* @param includeForces true if forces should be calculated
* @param includeEnergy true if the energy should be calculated
* @return the potential energy due to the force
*/
double execute(ContextImpl& context, bool includeForces, bool includeEnergy);
private:
const PythonForceComputation* computation;
std::vector<Vec3> forces;
bool usePeriodic;
};
} // namespace OpenMM
#endif /*OPENMM_REFERENCEKERNELS_H_*/
......@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -86,6 +86,8 @@ KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Pla
return new ReferenceCalcCustomCPPForceKernel(name, platform);
if (name == CalcOrientationRestraintForceKernel::Name())
return new ReferenceCalcOrientationRestraintForceKernel(name, platform);
if (name == CalcPythonForceKernel::Name())
return new ReferenceCalcPythonForceKernel(name, platform);
if (name == CalcRGForceKernel::Name())
return new ReferenceCalcRGForceKernel(name, platform);
if (name == CalcRMSDForceKernel::Name())
......
......@@ -213,7 +213,7 @@ void ReferenceUpdateStateDataKernel::setStepCount(const ContextImpl& context, lo
data.stepCount = count;
}
void ReferenceUpdateStateDataKernel::getPositions(ContextImpl& context, std::vector<Vec3>& positions) {
void ReferenceUpdateStateDataKernel::getPositions(ContextImpl& context, std::vector<Vec3>& positions, bool allowPeriodic) {
positions = extractPositions(context);
}
......@@ -3486,3 +3486,29 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc
forceData[i] += forces[i];
return energy;
}
void ReferenceCalcPythonForceKernel::initialize(const System& system, const PythonForce& force) {
computation = &force.getComputation();
forces.resize(system.getNumParticles());
usePeriodic = force.usesPeriodicBoundaryConditions();
}
double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
vector<Vec3>& posData = extractPositions(context);
vector<Vec3>& forceData = extractForces(context);
State::StateBuilder builder(context.getTime(), context.getStepCount());
builder.setPositions(posData);
builder.setParameters(context.getParameters());
if (usePeriodic) {
Vec3 a, b, c;
context.getPeriodicBoxVectors(a, b, c);
builder.setPeriodicBoxVectors(a, b, c);
}
double energy;
State state = builder.getState();
computation->compute(state, energy, forces.data(), true);
if (includeForces)
for (int i = 0; i < forces.size(); i++)
forceData[i] += forces[i];
return energy;
}
......@@ -65,6 +65,7 @@ ReferencePlatform::ReferencePlatform() {
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "ReferenceTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h"
#include "openmm/PythonForce.h"
#include "openmm/Platform.h"
#include "openmm/VerletIntegrator.h"
#include "sfmt/SFMT.h"
#include <iostream>
using namespace OpenMM;
using namespace std;
void testForce() {
class Computation : public PythonForceComputation {
void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const {
ASSERT_EQUAL(5.0, state.getParameters().at("a"));
ASSERT_EQUAL(10.0, state.getParameters().at("b"));
Vec3 a, b, c;
state.getPeriodicBoxVectors(a, b, c);
ASSERT_EQUAL(Vec3(2, 0, 0), a);
ASSERT_EQUAL(Vec3(0.1, 2, 0), b);
ASSERT_EQUAL(Vec3(0.1, 0.1, 2), c);
energy = 25.0;
int numParticles = state.getPositions().size();
for (int i = 0; i < numParticles; i++) {
Vec3 f = state.getPositions()[i]*2;
if (forcesAreDouble)
((Vec3*) forces)[i] = f;
else {
((float*) forces)[3*i] = (float) f[0];
((float*) forces)[3*i+1] = (float) f[1];
((float*) forces)[3*i+2] = (float) f[2];
}
}
}
};
int numParticles = 5;
System system;
Vec3 a(2, 0, 0);
Vec3 b(0.1, 2, 0);
Vec3 c(0.1, 0.1, 2);
system.setDefaultPeriodicBoxVectors(a, b, c);
vector<Vec3> positions;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt)));
}
map<string, double> params;
params["a"] = 5.0;
params["b"] = 10.0;
PythonForce* force = new PythonForce(new Computation(), params);
ASSERT(!force->usesPeriodicBoundaryConditions());
force->setUsesPeriodicBoundaryConditions(true);
ASSERT(force->usesPeriodicBoundaryConditions());
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
ASSERT_EQUAL_TOL(25.0, state.getPotentialEnergy(), 1e-6);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6)
// Check that force groups are handled correctly.
ASSERT_EQUAL_TOL(25.0, context.getState(State::Energy, false, 1).getPotentialEnergy(), 1e-6);
ASSERT_EQUAL_TOL(0.0, context.getState(State::Energy, false, 2).getPotentialEnergy(), 1e-6);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
try {
initializeTests(argc, argv);
testForce();
runPlatformTests();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
......@@ -65,7 +65,8 @@ class WrapperGenerator:
"""This is the parent class of generators for various API wrapper files. It defines functions common to all of them."""
def __init__(self, inputDirname, output):
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy']
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory',
'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy', 'OpenMM::PythonForce']
self.skipMethods = ['State OpenMM::Context::getState',
'void OpenMM::Context::createCheckpoint',
'void OpenMM::Context::loadCheckpoint',
......
......@@ -39,3 +39,5 @@ __version__ = Platform.getOpenMMVersion()
class OpenMMException(Exception):
"""This is the class used for all exceptions thrown by the C++ library."""
pass
registerPythonForceProxy()
\ No newline at end of file
......@@ -99,6 +99,7 @@ SKIP_METHODS = [('State', 'getPositions'),
('XmlSerializer', 'deserialize'),
("NoseHooverIntegrator", "getAllThermostatedIndividualParticles"),
("NoseHooverIntegrator", "getAllThermostatedPairs"),
("PythonForce", "PythonForce"),
]
......@@ -175,6 +176,7 @@ UNITS = {
("*", "setDefaultPressureZ") : (None, ("unit.bar",)),
("*", "getDefaultSurfaceTension") : ("unit.bar*unit.nanometer", ()),
("*", "setDefaultSurfaceTension") : (None, ("unit.bar*unit.nanometer",)),
("*", "computeCurrentPressure") : ("unit.bar", ()),
("*", "getDefaultTemperature") : ("unit.kelvin", ()),
("*", "setDefaultTemperature") : (None, ("unit.kelvin",)),
("*", "getRelativeTemperature") : ("unit.kelvin", ()),
......@@ -493,7 +495,6 @@ UNITS = {
("MonteCarloMembraneBarostat", "MonteCarloMembraneBarostat") : (None, ("unit.bar", "unit.bar*unit.nanometer", "unit.kelvin", None, None, None)),
("MonteCarloMembraneBarostat", "getXYMode") : (None, ()),
("MonteCarloMembraneBarostat", "getZMode") : (None, ()),
("*", "computeCurrentPressure") : ("unit.bar", ()),
("CustomIntegrator", "CustomIntegrator") : (None, ("unit.picosecond",)),
("BrownianIntegrator", "BrownianIntegrator") : (None, ("unit.kelvin", "unit.picosecond**-1", "unit.picosecond")),
("LangevinIntegrator", "LangevinIntegrator") : (None, ("unit.kelvin", "unit.picosecond**-1", "unit.picosecond")),
......@@ -571,4 +572,7 @@ UNITS = {
("ATMForce", "getParticleTransformation") : (None, ()),
("FixedDisplacement", "getFixedDisplacement1") : ("unit.nanometer", ()),
("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()),
("PythonForce", "getComputation") : (None, ()),
("PythonForce", "getGlobalParameters") : (None, ()),
("PythonForce", "getPickledFunction") : (None, ()),
}
......@@ -9,5 +9,6 @@
%include pythonprepend.i
%include pythonappend.i
%include typemaps.i
%include pythonforce.i
%feature("director") OpenMM::MinimizationReporter;
%newobject OpenMM::PythonForce::PythonForce;
%inline %{
#include <iomanip>
namespace OpenMM {
/**
* This is the PythonForceComputation that performs the computation for a PythonForce. It invokes the function
* provided by the user, validates the outputs, and converts them to the required format.
*/
class ComputationWrapper : public PythonForceComputation {
public:
ComputationWrapper(PyObject* computation) : computation(computation) {
}
void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const {
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
// Invoke the function.
swig_type_info* info = SWIGTYPE_p_OpenMM__State;
PyObject* wrappedState = SWIG_NewPointerObj((void*) &state, info, 0);
PyObject* result = PyObject_CallFunctionObjArgs(computation, wrappedState, NULL);
if (result == NULL) {
// The function raised an exception. Convert it to an OpenMMException.
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 12
PyObject *type;
PyObject *exception;
PyObject *traceback;
PyErr_Fetch(&type, &exception, &traceback);
#else
PyObject *exception = PyErr_GetRaisedException();
#endif
PyObject *message = PyObject_Str(exception);
std::string *ptr;
SWIG_AsPtr_std_string(message, &ptr);
Py_XDECREF(message);
PyGILState_Release(gstate);
throw OpenMMException(*ptr);
}
// Extract the return values.
if (!PyTuple_Check(result) || PyTuple_Size(result) != 2) {
PyGILState_Release(gstate);
throw OpenMMException("PythonForce: Expected two return values");
}
PyObject* pyenergy = Py_StripOpenMMUnits(PyTuple_GetItem(result, 0));
PyObject* pyforces = Py_StripOpenMMUnits(PyTuple_GetItem(result, 1));
energy = PyFloat_AsDouble(pyenergy);
// Copy the forces to the output vector.
if (!PyArray_Check(pyforces) || PyArray_NDIM((PyArrayObject*) pyforces) != 2) {
PyGILState_Release(gstate);
throw OpenMMException("PythonForce: The forces must be returned in a 2-dimensional NumPy array");
}
npy_intp* dims = PyArray_DIMS((PyArrayObject*) pyforces);
int numParticles = state.getPositions().size();
if (dims[0] != numParticles || dims[1] != 3) {
PyGILState_Release(gstate);
throw OpenMMException("PythonForce: The forces must be returned in a NumPy array of shape (# particles, 3)");
}
PyObject* array;
int targetType = (forcesAreDouble ? NPY_DOUBLE : NPY_FLOAT);
if (PyArray_CHKFLAGS((PyArrayObject*) pyforces, NPY_ARRAY_CARRAY_RO) && PyArray_DESCR((PyArrayObject*) pyforces)->type_num == targetType)
array = pyforces;
else
array = PyArray_FromAny(pyforces, PyArray_DescrFromType(targetType), 2, 2, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_FORCECAST, NULL);
int elementSize = (forcesAreDouble ? sizeof(double) : sizeof(float));
void* data = PyArray_DATA((PyArrayObject*) array);
memcpy(forces, data, 3*elementSize*numParticles);
Py_XDECREF(wrappedState);
Py_XDECREF(result);
Py_XDECREF(pyenergy);
Py_XDECREF(pyforces);
if (array != pyforces)
Py_XDECREF(array);
PyGILState_Release(gstate);
}
private:
PyObject* computation;
};
/**
* Construct a new PythonForce.
*/
PythonForce* _createPythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) {
PythonForce* force = new PythonForce(new ComputationWrapper(computation), globalParameters);
PyObject* pickle = PyImport_ImportModule("pickle");
PyObject* dumps = PyUnicode_FromString("dumps");
PyObject* result = PyObject_CallMethodOneArg(pickle, dumps, computation);
if (result == NULL) {
// It couldn't be pickled. It will still work, but can't be serialized. Clear the error flag.
PyErr_Clear();
}
else {
char* buffer;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(result, &buffer, &len) == 0)
force->setPickledFunction(buffer, len);
}
return force;
}
/**
* This is the serialization proxy used to serialize PythonForce objects.
*/
class PythonForceProxy : public SerializationProxy {
public:
PythonForceProxy() : SerializationProxy("PythonForce") {
}
static std::string hexEncode(const std::vector<char>& input) {
std::stringstream ss;
ss << std::hex << std::setfill('0');
for (unsigned char i : input)
ss << std::setw(2) << static_cast<uint64_t>(i);
return ss.str();
}
static std::vector<char> hexDecode(const std::string& input) {
std::vector<char> res;
res.reserve(input.size() / 2);
for (size_t i = 0; i < input.length(); i += 2) {
std::istringstream iss(input.substr(i, 2));
uint64_t temp;
iss >> std::hex >> temp;
res.push_back(static_cast<unsigned char>(temp));
}
return res;
}
void serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const PythonForce& force = *reinterpret_cast<const PythonForce*>(object);
if (force.getPickledFunction().size() == 0)
throw OpenMMException("PythonForceProxy: Could not serialize PythonForce because its function could not be pickled.");
node.setStringProperty("function", hexEncode(force.getPickledFunction()));
node.setIntProperty("forceGroup", force.getForceGroup());
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
SerializationNode& globalParams = node.createChildNode("GlobalParameters");
for (auto param : force.getGlobalParameters())
globalParams.createChildNode("Parameter").setStringProperty("name", param.first).setDoubleProperty("default", param.second);
}
void* deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version");
if (version != 1)
throw OpenMMException("Unsupported version number");
std::vector<char> pickledFunction = hexDecode(node.getStringProperty("function"));
PyObject* pickle = PyImport_ImportModule("pickle");
PyObject* loads = PyUnicode_FromString("loads");
PyObject *pythonBytes = PyBytes_FromStringAndSize(pickledFunction.data(), pickledFunction.size());
PyObject *function = PyObject_CallMethodOneArg(pickle, loads, pythonBytes);
Py_XDECREF(pythonBytes);
const SerializationNode& paramsNode = node.getChildNode("GlobalParameters");
std::map<std::string, double> params;
for (auto& parameter : paramsNode.getChildren())
params[parameter.getStringProperty("name")] = parameter.getDoubleProperty("default");
PythonForce* force = _createPythonForce(function, params);
if (node.hasProperty("forceGroup"))
force->setForceGroup(node.getIntProperty("forceGroup", 0));
if (node.hasProperty("usesPeriodic"))
force->setUsesPeriodicBoundaryConditions(node.getBoolProperty("usesPeriodic"));
return force;
}
};
/**
* Register the serialization proxy. This function is invoked automatically when the openmm module is imported.
*/
void registerPythonForceProxy() {
SerializationProxy::registerProxy(typeid(PythonForce), new PythonForceProxy());
}
}
%}
%extend OpenMM::PythonForce {
%feature("docstring") PythonForce "Create a PythonForce.
Parameters
----------
computation : function
A function that performs the computation. It should take a State as its argument
and return two values: the potential energy (a scalar) and the forces (a NumPy array).
globalParameters : dict
Any global parameters the function depends on. Keys are the parameter names, and the
corresponding values are their default values.
"
PythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) {
return _createPythonForce(computation, globalParameters);
}
}
......@@ -18,7 +18,7 @@ PyObject* Vec3_to_PyVec3(const OpenMM::Vec3& v) {
}
}
%fragment("Py_StripOpenMMUnits", "header") {
%header {
/**
* Strip any OpenMM units of an input PyObject.
......@@ -112,7 +112,7 @@ PyObject* Py_StripOpenMMUnits(PyObject *input) {
}
%fragment("Py_SequenceToVec3", "header", fragment="Py_StripOpenMMUnits") {
%fragment("Py_SequenceToVec3", "header") {
OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) {
PyObject* s, *o, *o1;
double x[3];
......@@ -154,7 +154,7 @@ OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) {
}
}
%fragment("Py_SequenceToVecDouble", "header", fragment="Py_StripOpenMMUnits") {
%fragment("Py_SequenceToVecDouble", "header") {
int Py_SequenceToVecDouble(PyObject* obj, std::vector<double>& out) {
PyObject* stripped = Py_StripOpenMMUnits(obj);
PyObject* item = NULL;
......@@ -396,14 +396,14 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec
// ------ typemap for double ----
%typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE, fragment="Py_StripOpenMMUnits") double {
%typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE) double {
double argp = 0;
PyObject* s = NULL;
s = Py_StripOpenMMUnits($input);
$1 = (s != NULL) ? SWIG_IsOK(SWIG_AsVal_double(s, &argp)) : 0;
Py_DECREF(s);
}
%typemap(in, noblock=1, fragment="Py_StripOpenMMUnits") double (double argp = 0, int res = 0,
%typemap(in, noblock=1) double (double argp = 0, int res = 0,
PyObject* stripped = NULL) {
stripped = Py_StripOpenMMUnits($input);
......
import unittest
from openmm import *
from openmm.unit import *
import numpy as np
import copy
def compute(state):
"""This is a computation function used by the test cases."""
pos = state.getPositions(asNumpy=True).value_in_unit(nanometer)
k = state.getParameters()['k']
energy = k*np.sum(pos*pos)
force = -0.5*k*pos
return energy*kilojoules_per_mole, force*kilojoules_per_mole/nanometer
class TestPythonForce(unittest.TestCase):
"""Test the PythonForce class"""
def testComputeForce(self):
"""Test using PythonForce to compute forces."""
system = System()
for i in range(5):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
system.addForce(force)
positions = np.random.rand(5, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
state = context.getState(energy=True, forces=True)
self.assertAlmostEqual(2.5*np.sum(positions*positions), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5)
self.assertTrue(np.allclose(-1.25*positions, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer)))
def testExceptions(self):
"""Test that PythonForce handles exceptions correctly."""
def compute2(state):
raise ValueError('This should fail')
system = System()
system.addParticle(1.0)
force = PythonForce(compute2)
system.addForce(force)
positions = np.random.rand(1, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
with self.assertRaises(OpenMMException) as cm:
context.getState(energy=True)
self.assertEqual('This should fail', str(cm.exception))
def testSerialize(self):
"""Test that PythonForce can be serialized."""
force1 = PythonForce(compute, {'k':2.5})
force1.setUsesPeriodicBoundaryConditions(True)
# Make a copy by serializing and the deserializing it.
force2 = copy.copy(force1)
# They should be identical.
self.assertEqual(XmlSerializer.serialize(force1), XmlSerializer.serialize(force2))
self.assertEqual(dict(force2.getGlobalParameters()), {'k':2.5})
self.assertTrue(force2.usesPeriodicBoundaryConditions())
# A locally defined function cannot be pickled. We should not be able to serialize a force
# that uses it.
def compute2(state):
return 1.0, np.zeros(len(state.getPositions()), 3)
force3 = PythonForce(compute2)
with self.assertRaises(OpenMMException):
XmlSerializer.serialize(force3)
def testMinimization(self):
"""Test that PythonForce works correctly with the minimizer."""
system = System()
for i in range(5):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
system.addForce(force)
positions = np.random.rand(5, 3)
integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatform('Reference'))
context.setPositions(positions)
# The PythonForce and the MinimizationReporter both involve calling back into Python code,
# possibly from different threads. Make sure it doesn't cause any problems.
class Reporter(MinimizationReporter):
count = 0
def report(self, iteration, x, grad, args):
self.count += 1
return False
reporter = Reporter()
LocalEnergyMinimizer.minimize(context, tolerance=1e-3, reporter=reporter)
self.assertTrue(reporter.count > 0)
state = context.getState(energy=True, positions=True)
self.assertAlmostEqual(0.0, state.getPotentialEnergy().value_in_unit(kilojoules_per_mole))
def testMemory(self):
"""Test for memory leaks in the Python/C++ interface."""
try:
import resource
except:
# The resource module is not available on Windows.
return
system = System()
for i in range(1000):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
system.addForce(force)
positions = np.random.rand(1000, 3)
integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatform('Reference'))
context.setPositions(positions)
integrator.step(5000)
memory1 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
integrator.step(5000)
memory2 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
self.assertTrue(memory2 < 1.05*memory1)
def testDtypes(self):
"""Test returning forces with different types."""
for dtype in [np.float32, np.float64, int]:
def compute2(state):
return 0, np.array([[1,2,3],[4,5,6]], dtype=dtype)
system = System()
system.addParticle(1.0)
system.addParticle(1.0)
force = PythonForce(compute2)
system.addForce(force)
positions = np.random.rand(2, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
forces = context.getState(forces=True).getForces().value_in_unit(kilojoules_per_mole/nanometer)
self.assertEqual(Vec3(1,2,3), forces[0])
self.assertEqual(Vec3(4,5,6), forces[1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment