Unverified Commit 2fbed592 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Created PythonForce (#5122)

* Initial implementation of PythonForce

* Continuing implementation of PythonForce

* Tests for PythonForce

* Fix memory leaks

* Documentation for PythonForce

* Fixed incorrect return type

* Fix compilation error on Python older than 3.12

* Handle all dtypes

* Optimizations to PythonForce

* Optimized getPositions()

* Test all platforms

* Fix test failures
parent 74912095
...@@ -124,6 +124,8 @@ KernelImpl* HipKernelFactory::createKernelImpl(std::string name, const Platform& ...@@ -124,6 +124,8 @@ KernelImpl* HipKernelFactory::createKernelImpl(std::string name, const Platform&
return new CommonCalcCustomCPPForceKernel(name, platform, context, cu); return new CommonCalcCustomCPPForceKernel(name, platform, context, cu);
if (name == CalcOrientationRestraintForceKernel::Name()) if (name == CalcOrientationRestraintForceKernel::Name())
return new CommonCalcOrientationRestraintForceKernel(name, platform, cu); return new CommonCalcOrientationRestraintForceKernel(name, platform, cu);
if (name == CalcPythonForceKernel::Name())
return new CommonCalcPythonForceKernel(name, platform, context, cu);
if (name == CalcRGForceKernel::Name()) if (name == CalcRGForceKernel::Name())
return new CommonCalcRGForceKernel(name, platform, cu); return new CommonCalcRGForceKernel(name, platform, cu);
if (name == CalcRMSDForceKernel::Name()) if (name == CalcRMSDForceKernel::Name())
......
...@@ -94,6 +94,7 @@ HipPlatform::HipPlatform() { ...@@ -94,6 +94,7 @@ HipPlatform::HipPlatform() {
registerKernelFactory(CalcCustomCPPForceKernel::Name(), factory); registerKernelFactory(CalcCustomCPPForceKernel::Name(), factory);
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory); registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory); registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory); registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory); registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory); registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "HipTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
...@@ -123,6 +123,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo ...@@ -123,6 +123,8 @@ KernelImpl* OpenCLKernelFactory::createKernelImpl(std::string name, const Platfo
return new CommonCalcCustomCPPForceKernel(name, platform, context, cl); return new CommonCalcCustomCPPForceKernel(name, platform, context, cl);
if (name == CalcOrientationRestraintForceKernel::Name()) if (name == CalcOrientationRestraintForceKernel::Name())
return new CommonCalcOrientationRestraintForceKernel(name, platform, cl); return new CommonCalcOrientationRestraintForceKernel(name, platform, cl);
if (name == CalcPythonForceKernel::Name())
return new CommonCalcPythonForceKernel(name, platform, context, cl);
if (name == CalcRGForceKernel::Name()) if (name == CalcRGForceKernel::Name())
return new CommonCalcRGForceKernel(name, platform, cl); return new CommonCalcRGForceKernel(name, platform, cl);
if (name == CalcRMSDForceKernel::Name()) if (name == CalcRMSDForceKernel::Name())
......
...@@ -85,6 +85,7 @@ OpenCLPlatform::OpenCLPlatform() { ...@@ -85,6 +85,7 @@ OpenCLPlatform::OpenCLPlatform() {
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory); registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory); registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory); registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory); registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory); registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory); registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "OpenCLTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
...@@ -155,9 +155,12 @@ public: ...@@ -155,9 +155,12 @@ public:
/** /**
* Get the positions of all particles. * Get the positions of all particles.
* *
* @param context the context in which to execute this kernel
* @param positions on exit, this contains the particle positions * @param positions on exit, this contains the particle positions
* @param allowPeriodic if true, the returned positions might be translated into a
* different periodic box to keep them closer to the origin
*/ */
void getPositions(ContextImpl& context, std::vector<Vec3>& positions); void getPositions(ContextImpl& context, std::vector<Vec3>& positions, bool allowPeriodic=false);
/** /**
* Set the positions of all particles. * Set the positions of all particles.
* *
...@@ -1944,6 +1947,35 @@ private: ...@@ -1944,6 +1947,35 @@ private:
std::vector<Vec3> forces; std::vector<Vec3> forces;
}; };
/**
* This kernel is invoked by PythonForceImpl to calculate the forces acting on the system and the energy of the system.
*/
class ReferenceCalcPythonForceKernel : public CalcPythonForceKernel {
public:
ReferenceCalcPythonForceKernel(std::string name, const Platform& platform) : CalcPythonForceKernel(name, platform) {
}
/**
* Initialize the kernel.
*
* @param system the System this kernel will be applied to
* @param force the PythonForce this kernel will be used for
*/
void initialize(const System& system, const PythonForce& force);
/**
* Execute the kernel to calculate the forces and/or energy.
*
* @param context the context in which to execute this kernel
* @param includeForces true if forces should be calculated
* @param includeEnergy true if the energy should be calculated
* @return the potential energy due to the force
*/
double execute(ContextImpl& context, bool includeForces, bool includeEnergy);
private:
const PythonForceComputation* computation;
std::vector<Vec3> forces;
bool usePeriodic;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_REFERENCEKERNELS_H_*/ #endif /*OPENMM_REFERENCEKERNELS_H_*/
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. * * Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -86,6 +86,8 @@ KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Pla ...@@ -86,6 +86,8 @@ KernelImpl* ReferenceKernelFactory::createKernelImpl(std::string name, const Pla
return new ReferenceCalcCustomCPPForceKernel(name, platform); return new ReferenceCalcCustomCPPForceKernel(name, platform);
if (name == CalcOrientationRestraintForceKernel::Name()) if (name == CalcOrientationRestraintForceKernel::Name())
return new ReferenceCalcOrientationRestraintForceKernel(name, platform); return new ReferenceCalcOrientationRestraintForceKernel(name, platform);
if (name == CalcPythonForceKernel::Name())
return new ReferenceCalcPythonForceKernel(name, platform);
if (name == CalcRGForceKernel::Name()) if (name == CalcRGForceKernel::Name())
return new ReferenceCalcRGForceKernel(name, platform); return new ReferenceCalcRGForceKernel(name, platform);
if (name == CalcRMSDForceKernel::Name()) if (name == CalcRMSDForceKernel::Name())
......
...@@ -213,7 +213,7 @@ void ReferenceUpdateStateDataKernel::setStepCount(const ContextImpl& context, lo ...@@ -213,7 +213,7 @@ void ReferenceUpdateStateDataKernel::setStepCount(const ContextImpl& context, lo
data.stepCount = count; data.stepCount = count;
} }
void ReferenceUpdateStateDataKernel::getPositions(ContextImpl& context, std::vector<Vec3>& positions) { void ReferenceUpdateStateDataKernel::getPositions(ContextImpl& context, std::vector<Vec3>& positions, bool allowPeriodic) {
positions = extractPositions(context); positions = extractPositions(context);
} }
...@@ -3486,3 +3486,29 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc ...@@ -3486,3 +3486,29 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc
forceData[i] += forces[i]; forceData[i] += forces[i];
return energy; return energy;
} }
void ReferenceCalcPythonForceKernel::initialize(const System& system, const PythonForce& force) {
computation = &force.getComputation();
forces.resize(system.getNumParticles());
usePeriodic = force.usesPeriodicBoundaryConditions();
}
double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
vector<Vec3>& posData = extractPositions(context);
vector<Vec3>& forceData = extractForces(context);
State::StateBuilder builder(context.getTime(), context.getStepCount());
builder.setPositions(posData);
builder.setParameters(context.getParameters());
if (usePeriodic) {
Vec3 a, b, c;
context.getPeriodicBoxVectors(a, b, c);
builder.setPeriodicBoxVectors(a, b, c);
}
double energy;
State state = builder.getState();
computation->compute(state, energy, forces.data(), true);
if (includeForces)
for (int i = 0; i < forces.size(); i++)
forceData[i] += forces[i];
return energy;
}
...@@ -65,6 +65,7 @@ ReferencePlatform::ReferencePlatform() { ...@@ -65,6 +65,7 @@ ReferencePlatform::ReferencePlatform() {
registerKernelFactory(CalcCustomCVForceKernel::Name(), factory); registerKernelFactory(CalcCustomCVForceKernel::Name(), factory);
registerKernelFactory(CalcATMForceKernel::Name(), factory); registerKernelFactory(CalcATMForceKernel::Name(), factory);
registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory); registerKernelFactory(CalcOrientationRestraintForceKernel::Name(), factory);
registerKernelFactory(CalcPythonForceKernel::Name(), factory);
registerKernelFactory(CalcRGForceKernel::Name(), factory); registerKernelFactory(CalcRGForceKernel::Name(), factory);
registerKernelFactory(CalcRMSDForceKernel::Name(), factory); registerKernelFactory(CalcRMSDForceKernel::Name(), factory);
registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory); registerKernelFactory(CalcCustomManyParticleForceKernel::Name(), factory);
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "ReferenceTests.h"
#include "TestPythonForce.h"
void runPlatformTests() {
}
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h"
#include "openmm/PythonForce.h"
#include "openmm/Platform.h"
#include "openmm/VerletIntegrator.h"
#include "sfmt/SFMT.h"
#include <iostream>
using namespace OpenMM;
using namespace std;
void testForce() {
class Computation : public PythonForceComputation {
void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const {
ASSERT_EQUAL(5.0, state.getParameters().at("a"));
ASSERT_EQUAL(10.0, state.getParameters().at("b"));
Vec3 a, b, c;
state.getPeriodicBoxVectors(a, b, c);
ASSERT_EQUAL(Vec3(2, 0, 0), a);
ASSERT_EQUAL(Vec3(0.1, 2, 0), b);
ASSERT_EQUAL(Vec3(0.1, 0.1, 2), c);
energy = 25.0;
int numParticles = state.getPositions().size();
for (int i = 0; i < numParticles; i++) {
Vec3 f = state.getPositions()[i]*2;
if (forcesAreDouble)
((Vec3*) forces)[i] = f;
else {
((float*) forces)[3*i] = (float) f[0];
((float*) forces)[3*i+1] = (float) f[1];
((float*) forces)[3*i+2] = (float) f[2];
}
}
}
};
int numParticles = 5;
System system;
Vec3 a(2, 0, 0);
Vec3 b(0.1, 2, 0);
Vec3 c(0.1, 0.1, 2);
system.setDefaultPeriodicBoxVectors(a, b, c);
vector<Vec3> positions;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
system.addParticle(1.0);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt)));
}
map<string, double> params;
params["a"] = 5.0;
params["b"] = 10.0;
PythonForce* force = new PythonForce(new Computation(), params);
ASSERT(!force->usesPeriodicBoundaryConditions());
force->setUsesPeriodicBoundaryConditions(true);
ASSERT(force->usesPeriodicBoundaryConditions());
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
ASSERT_EQUAL_TOL(25.0, state.getPotentialEnergy(), 1e-6);
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6)
// Check that force groups are handled correctly.
ASSERT_EQUAL_TOL(25.0, context.getState(State::Energy, false, 1).getPotentialEnergy(), 1e-6);
ASSERT_EQUAL_TOL(0.0, context.getState(State::Energy, false, 2).getPotentialEnergy(), 1e-6);
}
void runPlatformTests();
int main(int argc, char* argv[]) {
try {
initializeTests(argc, argv);
testForce();
runPlatformTests();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
...@@ -65,7 +65,8 @@ class WrapperGenerator: ...@@ -65,7 +65,8 @@ class WrapperGenerator:
"""This is the parent class of generators for various API wrapper files. It defines functions common to all of them.""" """This is the parent class of generators for various API wrapper files. It defines functions common to all of them."""
def __init__(self, inputDirname, output): def __init__(self, inputDirname, output):
self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory', 'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy'] self.skipClasses = ['OpenMM::Vec3', 'OpenMM::XmlSerializer', 'OpenMM::Kernel', 'OpenMM::KernelImpl', 'OpenMM::KernelFactory',
'OpenMM::ContextImpl', 'OpenMM::SerializationNode', 'OpenMM::SerializationProxy', 'OpenMM::PythonForce']
self.skipMethods = ['State OpenMM::Context::getState', self.skipMethods = ['State OpenMM::Context::getState',
'void OpenMM::Context::createCheckpoint', 'void OpenMM::Context::createCheckpoint',
'void OpenMM::Context::loadCheckpoint', 'void OpenMM::Context::loadCheckpoint',
......
...@@ -39,3 +39,5 @@ __version__ = Platform.getOpenMMVersion() ...@@ -39,3 +39,5 @@ __version__ = Platform.getOpenMMVersion()
class OpenMMException(Exception): class OpenMMException(Exception):
"""This is the class used for all exceptions thrown by the C++ library.""" """This is the class used for all exceptions thrown by the C++ library."""
pass pass
registerPythonForceProxy()
\ No newline at end of file
...@@ -99,6 +99,7 @@ SKIP_METHODS = [('State', 'getPositions'), ...@@ -99,6 +99,7 @@ SKIP_METHODS = [('State', 'getPositions'),
('XmlSerializer', 'deserialize'), ('XmlSerializer', 'deserialize'),
("NoseHooverIntegrator", "getAllThermostatedIndividualParticles"), ("NoseHooverIntegrator", "getAllThermostatedIndividualParticles"),
("NoseHooverIntegrator", "getAllThermostatedPairs"), ("NoseHooverIntegrator", "getAllThermostatedPairs"),
("PythonForce", "PythonForce"),
] ]
...@@ -175,6 +176,7 @@ UNITS = { ...@@ -175,6 +176,7 @@ UNITS = {
("*", "setDefaultPressureZ") : (None, ("unit.bar",)), ("*", "setDefaultPressureZ") : (None, ("unit.bar",)),
("*", "getDefaultSurfaceTension") : ("unit.bar*unit.nanometer", ()), ("*", "getDefaultSurfaceTension") : ("unit.bar*unit.nanometer", ()),
("*", "setDefaultSurfaceTension") : (None, ("unit.bar*unit.nanometer",)), ("*", "setDefaultSurfaceTension") : (None, ("unit.bar*unit.nanometer",)),
("*", "computeCurrentPressure") : ("unit.bar", ()),
("*", "getDefaultTemperature") : ("unit.kelvin", ()), ("*", "getDefaultTemperature") : ("unit.kelvin", ()),
("*", "setDefaultTemperature") : (None, ("unit.kelvin",)), ("*", "setDefaultTemperature") : (None, ("unit.kelvin",)),
("*", "getRelativeTemperature") : ("unit.kelvin", ()), ("*", "getRelativeTemperature") : ("unit.kelvin", ()),
...@@ -493,7 +495,6 @@ UNITS = { ...@@ -493,7 +495,6 @@ UNITS = {
("MonteCarloMembraneBarostat", "MonteCarloMembraneBarostat") : (None, ("unit.bar", "unit.bar*unit.nanometer", "unit.kelvin", None, None, None)), ("MonteCarloMembraneBarostat", "MonteCarloMembraneBarostat") : (None, ("unit.bar", "unit.bar*unit.nanometer", "unit.kelvin", None, None, None)),
("MonteCarloMembraneBarostat", "getXYMode") : (None, ()), ("MonteCarloMembraneBarostat", "getXYMode") : (None, ()),
("MonteCarloMembraneBarostat", "getZMode") : (None, ()), ("MonteCarloMembraneBarostat", "getZMode") : (None, ()),
("*", "computeCurrentPressure") : ("unit.bar", ()),
("CustomIntegrator", "CustomIntegrator") : (None, ("unit.picosecond",)), ("CustomIntegrator", "CustomIntegrator") : (None, ("unit.picosecond",)),
("BrownianIntegrator", "BrownianIntegrator") : (None, ("unit.kelvin", "unit.picosecond**-1", "unit.picosecond")), ("BrownianIntegrator", "BrownianIntegrator") : (None, ("unit.kelvin", "unit.picosecond**-1", "unit.picosecond")),
("LangevinIntegrator", "LangevinIntegrator") : (None, ("unit.kelvin", "unit.picosecond**-1", "unit.picosecond")), ("LangevinIntegrator", "LangevinIntegrator") : (None, ("unit.kelvin", "unit.picosecond**-1", "unit.picosecond")),
...@@ -571,4 +572,7 @@ UNITS = { ...@@ -571,4 +572,7 @@ UNITS = {
("ATMForce", "getParticleTransformation") : (None, ()), ("ATMForce", "getParticleTransformation") : (None, ()),
("FixedDisplacement", "getFixedDisplacement1") : ("unit.nanometer", ()), ("FixedDisplacement", "getFixedDisplacement1") : ("unit.nanometer", ()),
("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()), ("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()),
("PythonForce", "getComputation") : (None, ()),
("PythonForce", "getGlobalParameters") : (None, ()),
("PythonForce", "getPickledFunction") : (None, ()),
} }
...@@ -9,5 +9,6 @@ ...@@ -9,5 +9,6 @@
%include pythonprepend.i %include pythonprepend.i
%include pythonappend.i %include pythonappend.i
%include typemaps.i %include typemaps.i
%include pythonforce.i
%feature("director") OpenMM::MinimizationReporter; %feature("director") OpenMM::MinimizationReporter;
%newobject OpenMM::PythonForce::PythonForce;
%inline %{
#include <iomanip>
namespace OpenMM {
/**
* This is the PythonForceComputation that performs the computation for a PythonForce. It invokes the function
* provided by the user, validates the outputs, and converts them to the required format.
*/
class ComputationWrapper : public PythonForceComputation {
public:
ComputationWrapper(PyObject* computation) : computation(computation) {
}
void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const {
PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
// Invoke the function.
swig_type_info* info = SWIGTYPE_p_OpenMM__State;
PyObject* wrappedState = SWIG_NewPointerObj((void*) &state, info, 0);
PyObject* result = PyObject_CallFunctionObjArgs(computation, wrappedState, NULL);
if (result == NULL) {
// The function raised an exception. Convert it to an OpenMMException.
#if PY_MAJOR_VERSION == 3 && PY_MINOR_VERSION < 12
PyObject *type;
PyObject *exception;
PyObject *traceback;
PyErr_Fetch(&type, &exception, &traceback);
#else
PyObject *exception = PyErr_GetRaisedException();
#endif
PyObject *message = PyObject_Str(exception);
std::string *ptr;
SWIG_AsPtr_std_string(message, &ptr);
Py_XDECREF(message);
PyGILState_Release(gstate);
throw OpenMMException(*ptr);
}
// Extract the return values.
if (!PyTuple_Check(result) || PyTuple_Size(result) != 2) {
PyGILState_Release(gstate);
throw OpenMMException("PythonForce: Expected two return values");
}
PyObject* pyenergy = Py_StripOpenMMUnits(PyTuple_GetItem(result, 0));
PyObject* pyforces = Py_StripOpenMMUnits(PyTuple_GetItem(result, 1));
energy = PyFloat_AsDouble(pyenergy);
// Copy the forces to the output vector.
if (!PyArray_Check(pyforces) || PyArray_NDIM((PyArrayObject*) pyforces) != 2) {
PyGILState_Release(gstate);
throw OpenMMException("PythonForce: The forces must be returned in a 2-dimensional NumPy array");
}
npy_intp* dims = PyArray_DIMS((PyArrayObject*) pyforces);
int numParticles = state.getPositions().size();
if (dims[0] != numParticles || dims[1] != 3) {
PyGILState_Release(gstate);
throw OpenMMException("PythonForce: The forces must be returned in a NumPy array of shape (# particles, 3)");
}
PyObject* array;
int targetType = (forcesAreDouble ? NPY_DOUBLE : NPY_FLOAT);
if (PyArray_CHKFLAGS((PyArrayObject*) pyforces, NPY_ARRAY_CARRAY_RO) && PyArray_DESCR((PyArrayObject*) pyforces)->type_num == targetType)
array = pyforces;
else
array = PyArray_FromAny(pyforces, PyArray_DescrFromType(targetType), 2, 2, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_FORCECAST, NULL);
int elementSize = (forcesAreDouble ? sizeof(double) : sizeof(float));
void* data = PyArray_DATA((PyArrayObject*) array);
memcpy(forces, data, 3*elementSize*numParticles);
Py_XDECREF(wrappedState);
Py_XDECREF(result);
Py_XDECREF(pyenergy);
Py_XDECREF(pyforces);
if (array != pyforces)
Py_XDECREF(array);
PyGILState_Release(gstate);
}
private:
PyObject* computation;
};
/**
* Construct a new PythonForce.
*/
PythonForce* _createPythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) {
PythonForce* force = new PythonForce(new ComputationWrapper(computation), globalParameters);
PyObject* pickle = PyImport_ImportModule("pickle");
PyObject* dumps = PyUnicode_FromString("dumps");
PyObject* result = PyObject_CallMethodOneArg(pickle, dumps, computation);
if (result == NULL) {
// It couldn't be pickled. It will still work, but can't be serialized. Clear the error flag.
PyErr_Clear();
}
else {
char* buffer;
Py_ssize_t len;
if (PyBytes_AsStringAndSize(result, &buffer, &len) == 0)
force->setPickledFunction(buffer, len);
}
return force;
}
/**
* This is the serialization proxy used to serialize PythonForce objects.
*/
class PythonForceProxy : public SerializationProxy {
public:
PythonForceProxy() : SerializationProxy("PythonForce") {
}
static std::string hexEncode(const std::vector<char>& input) {
std::stringstream ss;
ss << std::hex << std::setfill('0');
for (unsigned char i : input)
ss << std::setw(2) << static_cast<uint64_t>(i);
return ss.str();
}
static std::vector<char> hexDecode(const std::string& input) {
std::vector<char> res;
res.reserve(input.size() / 2);
for (size_t i = 0; i < input.length(); i += 2) {
std::istringstream iss(input.substr(i, 2));
uint64_t temp;
iss >> std::hex >> temp;
res.push_back(static_cast<unsigned char>(temp));
}
return res;
}
void serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
const PythonForce& force = *reinterpret_cast<const PythonForce*>(object);
if (force.getPickledFunction().size() == 0)
throw OpenMMException("PythonForceProxy: Could not serialize PythonForce because its function could not be pickled.");
node.setStringProperty("function", hexEncode(force.getPickledFunction()));
node.setIntProperty("forceGroup", force.getForceGroup());
node.setBoolProperty("usesPeriodic", force.usesPeriodicBoundaryConditions());
SerializationNode& globalParams = node.createChildNode("GlobalParameters");
for (auto param : force.getGlobalParameters())
globalParams.createChildNode("Parameter").setStringProperty("name", param.first).setDoubleProperty("default", param.second);
}
void* deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version");
if (version != 1)
throw OpenMMException("Unsupported version number");
std::vector<char> pickledFunction = hexDecode(node.getStringProperty("function"));
PyObject* pickle = PyImport_ImportModule("pickle");
PyObject* loads = PyUnicode_FromString("loads");
PyObject *pythonBytes = PyBytes_FromStringAndSize(pickledFunction.data(), pickledFunction.size());
PyObject *function = PyObject_CallMethodOneArg(pickle, loads, pythonBytes);
Py_XDECREF(pythonBytes);
const SerializationNode& paramsNode = node.getChildNode("GlobalParameters");
std::map<std::string, double> params;
for (auto& parameter : paramsNode.getChildren())
params[parameter.getStringProperty("name")] = parameter.getDoubleProperty("default");
PythonForce* force = _createPythonForce(function, params);
if (node.hasProperty("forceGroup"))
force->setForceGroup(node.getIntProperty("forceGroup", 0));
if (node.hasProperty("usesPeriodic"))
force->setUsesPeriodicBoundaryConditions(node.getBoolProperty("usesPeriodic"));
return force;
}
};
/**
* Register the serialization proxy. This function is invoked automatically when the openmm module is imported.
*/
void registerPythonForceProxy() {
SerializationProxy::registerProxy(typeid(PythonForce), new PythonForceProxy());
}
}
%}
%extend OpenMM::PythonForce {
%feature("docstring") PythonForce "Create a PythonForce.
Parameters
----------
computation : function
A function that performs the computation. It should take a State as its argument
and return two values: the potential energy (a scalar) and the forces (a NumPy array).
globalParameters : dict
Any global parameters the function depends on. Keys are the parameter names, and the
corresponding values are their default values.
"
PythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) {
return _createPythonForce(computation, globalParameters);
}
}
...@@ -18,7 +18,7 @@ PyObject* Vec3_to_PyVec3(const OpenMM::Vec3& v) { ...@@ -18,7 +18,7 @@ PyObject* Vec3_to_PyVec3(const OpenMM::Vec3& v) {
} }
} }
%fragment("Py_StripOpenMMUnits", "header") { %header {
/** /**
* Strip any OpenMM units of an input PyObject. * Strip any OpenMM units of an input PyObject.
...@@ -112,7 +112,7 @@ PyObject* Py_StripOpenMMUnits(PyObject *input) { ...@@ -112,7 +112,7 @@ PyObject* Py_StripOpenMMUnits(PyObject *input) {
} }
%fragment("Py_SequenceToVec3", "header", fragment="Py_StripOpenMMUnits") { %fragment("Py_SequenceToVec3", "header") {
OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) { OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) {
PyObject* s, *o, *o1; PyObject* s, *o, *o1;
double x[3]; double x[3];
...@@ -154,7 +154,7 @@ OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) { ...@@ -154,7 +154,7 @@ OpenMM::Vec3 Py_SequenceToVec3(PyObject* obj, int& status) {
} }
} }
%fragment("Py_SequenceToVecDouble", "header", fragment="Py_StripOpenMMUnits") { %fragment("Py_SequenceToVecDouble", "header") {
int Py_SequenceToVecDouble(PyObject* obj, std::vector<double>& out) { int Py_SequenceToVecDouble(PyObject* obj, std::vector<double>& out) {
PyObject* stripped = Py_StripOpenMMUnits(obj); PyObject* stripped = Py_StripOpenMMUnits(obj);
PyObject* item = NULL; PyObject* item = NULL;
...@@ -396,14 +396,14 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec ...@@ -396,14 +396,14 @@ int Py_SequenceToVecVecVecDouble(PyObject* obj, std::vector<std::vector<std::vec
// ------ typemap for double ---- // ------ typemap for double ----
%typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE, fragment="Py_StripOpenMMUnits") double { %typemap(typecheck, precedence=SWIG_TYPECHECK_DOUBLE) double {
double argp = 0; double argp = 0;
PyObject* s = NULL; PyObject* s = NULL;
s = Py_StripOpenMMUnits($input); s = Py_StripOpenMMUnits($input);
$1 = (s != NULL) ? SWIG_IsOK(SWIG_AsVal_double(s, &argp)) : 0; $1 = (s != NULL) ? SWIG_IsOK(SWIG_AsVal_double(s, &argp)) : 0;
Py_DECREF(s); Py_DECREF(s);
} }
%typemap(in, noblock=1, fragment="Py_StripOpenMMUnits") double (double argp = 0, int res = 0, %typemap(in, noblock=1) double (double argp = 0, int res = 0,
PyObject* stripped = NULL) { PyObject* stripped = NULL) {
stripped = Py_StripOpenMMUnits($input); stripped = Py_StripOpenMMUnits($input);
......
import unittest
from openmm import *
from openmm.unit import *
import numpy as np
import copy
def compute(state):
"""This is a computation function used by the test cases."""
pos = state.getPositions(asNumpy=True).value_in_unit(nanometer)
k = state.getParameters()['k']
energy = k*np.sum(pos*pos)
force = -0.5*k*pos
return energy*kilojoules_per_mole, force*kilojoules_per_mole/nanometer
class TestPythonForce(unittest.TestCase):
"""Test the PythonForce class"""
def testComputeForce(self):
"""Test using PythonForce to compute forces."""
system = System()
for i in range(5):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
system.addForce(force)
positions = np.random.rand(5, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
state = context.getState(energy=True, forces=True)
self.assertAlmostEqual(2.5*np.sum(positions*positions), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5)
self.assertTrue(np.allclose(-1.25*positions, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer)))
def testExceptions(self):
"""Test that PythonForce handles exceptions correctly."""
def compute2(state):
raise ValueError('This should fail')
system = System()
system.addParticle(1.0)
force = PythonForce(compute2)
system.addForce(force)
positions = np.random.rand(1, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
with self.assertRaises(OpenMMException) as cm:
context.getState(energy=True)
self.assertEqual('This should fail', str(cm.exception))
def testSerialize(self):
"""Test that PythonForce can be serialized."""
force1 = PythonForce(compute, {'k':2.5})
force1.setUsesPeriodicBoundaryConditions(True)
# Make a copy by serializing and the deserializing it.
force2 = copy.copy(force1)
# They should be identical.
self.assertEqual(XmlSerializer.serialize(force1), XmlSerializer.serialize(force2))
self.assertEqual(dict(force2.getGlobalParameters()), {'k':2.5})
self.assertTrue(force2.usesPeriodicBoundaryConditions())
# A locally defined function cannot be pickled. We should not be able to serialize a force
# that uses it.
def compute2(state):
return 1.0, np.zeros(len(state.getPositions()), 3)
force3 = PythonForce(compute2)
with self.assertRaises(OpenMMException):
XmlSerializer.serialize(force3)
def testMinimization(self):
"""Test that PythonForce works correctly with the minimizer."""
system = System()
for i in range(5):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
system.addForce(force)
positions = np.random.rand(5, 3)
integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatform('Reference'))
context.setPositions(positions)
# The PythonForce and the MinimizationReporter both involve calling back into Python code,
# possibly from different threads. Make sure it doesn't cause any problems.
class Reporter(MinimizationReporter):
count = 0
def report(self, iteration, x, grad, args):
self.count += 1
return False
reporter = Reporter()
LocalEnergyMinimizer.minimize(context, tolerance=1e-3, reporter=reporter)
self.assertTrue(reporter.count > 0)
state = context.getState(energy=True, positions=True)
self.assertAlmostEqual(0.0, state.getPotentialEnergy().value_in_unit(kilojoules_per_mole))
def testMemory(self):
"""Test for memory leaks in the Python/C++ interface."""
try:
import resource
except:
# The resource module is not available on Windows.
return
system = System()
for i in range(1000):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
system.addForce(force)
positions = np.random.rand(1000, 3)
integrator = VerletIntegrator(0.001)
context = Context(system, integrator, Platform.getPlatform('Reference'))
context.setPositions(positions)
integrator.step(5000)
memory1 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
integrator.step(5000)
memory2 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
self.assertTrue(memory2 < 1.05*memory1)
def testDtypes(self):
"""Test returning forces with different types."""
for dtype in [np.float32, np.float64, int]:
def compute2(state):
return 0, np.array([[1,2,3],[4,5,6]], dtype=dtype)
system = System()
system.addParticle(1.0)
system.addParticle(1.0)
force = PythonForce(compute2)
system.addForce(force)
positions = np.random.rand(2, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
forces = context.getState(forces=True).getForces().value_in_unit(kilojoules_per_mole/nanometer)
self.assertEqual(Vec3(1,2,3), forces[0])
self.assertEqual(Vec3(4,5,6), forces[1])
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment