Unverified Commit 9aae4bb3 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

PythonForce can be restricted to a subset of particles (#5246)

* PythonForce can be restricted to a subset of particles

* Fix exception with CUDA
parent 6717a85c
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2025 Stanford University and the Authors. * * Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -103,6 +103,15 @@ public: ...@@ -103,6 +103,15 @@ public:
* box vectors. The positions may also be wrapped into a different periodic box to keep them * box vectors. The positions may also be wrapped into a different periodic box to keep them
* closer to the origin and improve accuracy. * closer to the origin and improve accuracy.
* *
* A PythonForce can optionally be applied to only a subset of the particles in a system. To do
* this, call setParticles() on it, providing the indices of the particles to apply it to. The
* computation function should then proceed as if those particles were the entire system.
* state.getPositions() will return a smaller array containing only the positions of those
* particles, and the array of forces should similarly contain only those particles. That is,
* forces[i] should be the force on the i'th particle passed to setParticles(). When applying
* forces to only a small fraction of the particles in a system, this can greatly improve
* performance.
*
* When using XmlSerializer to save a PythonForce, it uses the Python pickle module to save * When using XmlSerializer to save a PythonForce, it uses the Python pickle module to save
* the computation function. If it cannot be pickled, you will not be able to serialize the * the computation function. If it cannot be pickled, you will not be able to serialize the
* PythonForce. Functions defined at the top level of a module can usually be pickled, but local * PythonForce. Functions defined at the top level of a module can usually be pickled, but local
...@@ -124,9 +133,12 @@ public: ...@@ -124,9 +133,12 @@ public:
* @param computation an object defining how the forces and energy should be computed * @param computation an object defining how the forces and energy should be computed
* @param globalParameters any global parameters used by the force. Keys are the parameter * @param globalParameters any global parameters used by the force. Keys are the parameter
* names, and the corresponding values are their default values. * names, and the corresponding values are their default values.
* @param particles the indices of the particles to use when computing the force. If
* this is empty (the default), all particles in the system will be used.
* @private * @private
*/ */
explicit PythonForce(PythonForceComputation* computation, const std::map<std::string, double>& globalParameters); explicit PythonForce(PythonForceComputation* computation, const std::map<std::string, double>& globalParameters,
const std::vector<int>& particles=std::vector<int>());
~PythonForce(); ~PythonForce();
/** /**
* Get the PythonForceComputation that defines the computation. * Get the PythonForceComputation that defines the computation.
...@@ -138,6 +150,18 @@ public: ...@@ -138,6 +150,18 @@ public:
* corresponding values are their default values. * corresponding values are their default values.
*/ */
const std::map<std::string, double>& getGlobalParameters() const; const std::map<std::string, double>& getGlobalParameters() const;
/**
* Get the indices of the particles to use when computing the force. If this
* is empty, all particles in the system will be used.
*/
const std::vector<int>& getParticles() const {
return particles;
}
/**
* Set the indices of the particles to use when computing the force. If this
* is empty, all particles in the system will be used.
*/
void setParticles(const std::vector<int>& particles);
/** /**
* Get the pickled representation of the computation function. If it cannot be pickled, * Get the pickled representation of the computation function. If it cannot be pickled,
* this will be an empty vector. * this will be an empty vector.
...@@ -168,6 +192,7 @@ private: ...@@ -168,6 +192,7 @@ private:
PythonForceComputation* computation; PythonForceComputation* computation;
std::map<std::string, double> globalParameters; std::map<std::string, double> globalParameters;
bool usePeriodic; bool usePeriodic;
std::vector<int> particles;
std::vector<char> pickled; std::vector<char> pickled;
}; };
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2025 Stanford University and the Authors. * * Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -33,8 +33,8 @@ ...@@ -33,8 +33,8 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
PythonForce::PythonForce(PythonForceComputation* computation, const map<string, double>& globalParameters) : PythonForce::PythonForce(PythonForceComputation* computation, const map<string, double>& globalParameters, const vector<int>& particles) :
computation(computation), globalParameters(globalParameters), usePeriodic(false) { computation(computation), globalParameters(globalParameters), usePeriodic(false), particles(particles) {
} }
PythonForce::~PythonForce() { PythonForce::~PythonForce() {
...@@ -49,6 +49,10 @@ const map<string, double>& PythonForce::getGlobalParameters() const { ...@@ -49,6 +49,10 @@ const map<string, double>& PythonForce::getGlobalParameters() const {
return globalParameters; return globalParameters;
} }
void PythonForce::setParticles(const std::vector<int>& particles) {
this->particles = particles;
}
bool PythonForce::usesPeriodicBoundaryConditions() const { bool PythonForce::usesPeriodicBoundaryConditions() const {
return usePeriodic; return usePeriodic;
} }
......
...@@ -1550,14 +1550,18 @@ private: ...@@ -1550,14 +1550,18 @@ private:
class ExecuteTask; class ExecuteTask;
class StartCalculationPreComputation; class StartCalculationPreComputation;
class AddForcesPostComputation; class AddForcesPostComputation;
class ReorderListener;
void getPositions();
void sortParticles();
OpenMM::ContextImpl& contextImpl; OpenMM::ContextImpl& contextImpl;
ComputeContext& cc; ComputeContext& cc;
const PythonForceComputation* computation; const PythonForceComputation* computation;
ComputeArray forcesArray; ComputeArray positionsArray, forcesArray, particlesArray, reorderedParticles;
ComputeKernel addForcesKernel; ComputeKernel copyPositionsKernel, addForcesKernel;
std::vector<Vec3> positionsVec; std::vector<Vec3> positionsVec;
std::vector<double> forcesVec; std::vector<double> forcesVec;
int forceGroupFlag; std::vector<int> particles;
int numParticles, forceGroupFlag;
double energy; double energy;
bool usePeriodic, useWorkerThread; bool usePeriodic, useWorkerThread;
}; };
......
...@@ -4816,23 +4816,58 @@ public: ...@@ -4816,23 +4816,58 @@ public:
CommonCalcPythonForceKernel& owner; CommonCalcPythonForceKernel& owner;
}; };
class CommonCalcPythonForceKernel::ReorderListener : public ComputeContext::ReorderListener {
public:
ReorderListener(CommonCalcPythonForceKernel& owner) : owner(owner) {
}
void execute() {
owner.sortParticles();
}
private:
CommonCalcPythonForceKernel& owner;
};
void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) { void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) {
ContextSelector selector(cc); ContextSelector selector(cc);
computation = &force.getComputation(); computation = &force.getComputation();
usePeriodic = force.usesPeriodicBoundaryConditions(); usePeriodic = force.usesPeriodicBoundaryConditions();
int numParticles = context.getSystem().getNumParticles(); particles = force.getParticles();
numParticles = particles.size();
if (numParticles == 0)
numParticles = context.getSystem().getNumParticles();
positionsVec.resize(numParticles); positionsVec.resize(numParticles);
forcesVec.resize(3*numParticles); forcesVec.resize(3*numParticles);
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float)); int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
positionsArray.initialize(cc, 3*numParticles, elementSize, "positions");
forcesArray.initialize(cc, 3*numParticles, elementSize, "forces"); forcesArray.initialize(cc, 3*numParticles, elementSize, "forces");
map<string, string> defines; map<string, string> defines;
defines["NUM_ATOMS"] = cc.intToString(numParticles); defines["NUM_ATOMS"] = cc.intToString(numParticles);
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
ComputeProgram program = cc.compileProgram(CommonKernelSources::customCppForce, defines); ComputeProgram program = cc.compileProgram(CommonKernelSources::pythonForce, defines);
addForcesKernel = program->createKernel("addForces"); if (particles.size() > 0) {
particlesArray.initialize<int>(cc, numParticles, "particles");
reorderedParticles.initialize<int>(cc, numParticles, "reorderedParticles");
particlesArray.upload(particles);
reorderedParticles.upload(particles);
cc.addReorderListener(new ReorderListener(*this));
copyPositionsKernel = program->createKernel("copyPositions");
copyPositionsKernel->addArg(cc.getPosq());
copyPositionsKernel->addArg(positionsArray);
copyPositionsKernel->addArg(reorderedParticles);
copyPositionsKernel->addArg(numParticles);
addForcesKernel = program->createKernel("addForcesSubset");
addForcesKernel->addArg(forcesArray); addForcesKernel->addArg(forcesArray);
addForcesKernel->addArg(cc.getLongForceBuffer()); addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray()); addForcesKernel->addArg(cc.getAtomIndexArray());
addForcesKernel->addArg(reorderedParticles);
addForcesKernel->addArg(numParticles);
}
else {
addForcesKernel = program->createKernel("addForcesAll");
addForcesKernel->addArg(forcesArray);
addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray());
}
forceGroupFlag = (1<<force.getForceGroup()); forceGroupFlag = (1<<force.getForceGroup());
useWorkerThread = (cc.getNumContexts() == 1); useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls()) for (const ForceImpl* impl : context.getForceImpls())
...@@ -4858,11 +4893,63 @@ double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -4858,11 +4893,63 @@ double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeFo
if (cc.getContextIndex() != 0) if (cc.getContextIndex() != 0)
return 0.0; return 0.0;
contextImpl.getPositions(positionsVec); getPositions();
executeOnWorkerThread(includeForces); executeOnWorkerThread(includeForces);
return addForces(includeForces, includeEnergy, -1); return addForces(includeForces, includeEnergy, -1);
} }
void CommonCalcPythonForceKernel::getPositions() {
// If the NonbondedUtilities uses periodic boundary conditions, the positions might have been
// wrapped to the periodic box. If this force also applies periodic boundary conditions, that's
// alright. Otherwise, we need to move them back.
bool fixPeriodic = usePeriodic || !cc.getNonbondedUtilities().getUsePeriodic();
if (particles.size() == 0) {
// The force applies to the whole system, so we can just use the standard getPositions().
contextImpl.getPositions(positionsVec, fixPeriodic);
}
else {
// Retrieve positions for the subset of particles the force is applied to.
ContextSelector selector(cc);
copyPositionsKernel->execute(numParticles);
if (cc.getUseDoublePrecision()) {
vector<double> pos(3*numParticles);
positionsArray.download(pos);
for (int i = 0; i < numParticles; i++)
positionsVec[i] = Vec3(pos[3*i], pos[3*i+1], pos[3*i+2]);
}
else {
vector<float> pos(3*numParticles);
positionsArray.download(pos);
for (int i = 0; i < numParticles; i++)
positionsVec[i] = Vec3((double) pos[3*i], (double) pos[3*i+1], (double) pos[3*i+2]);
}
if (fixPeriodic) {
Vec3 boxVectors[3];
cc.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
for (int i = 0; i < numParticles; ++i) {
mm_int4 offset = cc.getPosCellOffsets()[particles[i]];
positionsVec[i] -= boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
}
}
void CommonCalcPythonForceKernel::sortParticles() {
// Update the list of particles to account for reordering.
const vector<int>& order = cc.getAtomIndex();
vector<int> inverseOrder(order.size());
for (int i = 0; i < cc.getNumAtoms(); i++)
inverseOrder[order[i]] = i;
vector<int> reordered(particles.size());
for (int i = 0; i < particles.size(); i++)
reordered[i] = inverseOrder[particles[i]];
reorderedParticles.upload(reordered);
}
void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool includeEnergy, int groups) { void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool includeEnergy, int groups) {
if ((groups&forceGroupFlag) == 0) if ((groups&forceGroupFlag) == 0)
return; return;
...@@ -4873,7 +4960,7 @@ void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool incl ...@@ -4873,7 +4960,7 @@ void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool incl
} }
void CommonCalcPythonForceKernel::executeOnWorkerThread(bool includeForces) { void CommonCalcPythonForceKernel::executeOnWorkerThread(bool includeForces) {
contextImpl.getPositions(positionsVec, usePeriodic || !cc.getNonbondedUtilities().getUsePeriodic()); getPositions();
State::StateBuilder builder(contextImpl.getTime(), contextImpl.getStepCount()); State::StateBuilder builder(contextImpl.getTime(), contextImpl.getStepCount());
builder.setPositions(positionsVec); builder.setPositions(positionsVec);
builder.setParameters(contextImpl.getParameters()); builder.setParameters(contextImpl.getParameters());
......
KERNEL void copyPositions(GLOBAL const real4* RESTRICT posq, GLOBAL real* RESTRICT positions, GLOBAL int* RESTRICT particles, int numParticles) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
real4 pos = posq[particles[i]];
positions[3*i] = pos.x;
positions[3*i+1] = pos.y;
positions[3*i+2] = pos.z;
}
}
KERNEL void addForcesAll(GLOBAL const real* RESTRICT forces, GLOBAL mm_long* RESTRICT forceBuffers, GLOBAL int* RESTRICT atomIndex) {
for (int atom = GLOBAL_ID; atom < NUM_ATOMS; atom += GLOBAL_SIZE) {
int index = atomIndex[atom];
forceBuffers[atom] += (mm_long) (forces[3*index]*0x100000000);
forceBuffers[atom+PADDED_NUM_ATOMS] += (mm_long) (forces[3*index+1]*0x100000000);
forceBuffers[atom+2*PADDED_NUM_ATOMS] += (mm_long) (forces[3*index+2]*0x100000000);
}
}
KERNEL void addForcesSubset(GLOBAL const real* RESTRICT forces, GLOBAL mm_long* RESTRICT forceBuffers, GLOBAL int* RESTRICT atomIndex, GLOBAL int* RESTRICT particles, int numParticles) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
int index = particles[i];
forceBuffers[index] += (mm_long) (forces[3*i]*0x100000000);
forceBuffers[index+PADDED_NUM_ATOMS] += (mm_long) (forces[3*i+1]*0x100000000);
forceBuffers[index+2*PADDED_NUM_ATOMS] += (mm_long) (forces[3*i+2]*0x100000000);
}
}
...@@ -2035,7 +2035,9 @@ public: ...@@ -2035,7 +2035,9 @@ public:
double execute(ContextImpl& context, bool includeForces, bool includeEnergy); double execute(ContextImpl& context, bool includeForces, bool includeEnergy);
private: private:
const PythonForceComputation* computation; const PythonForceComputation* computation;
std::vector<Vec3> forces; std::vector<Vec3> positions, forces;
std::vector<int> particles;
int numParticles;
bool usePeriodic; bool usePeriodic;
}; };
......
...@@ -3552,7 +3552,13 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc ...@@ -3552,7 +3552,13 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc
void ReferenceCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) { void ReferenceCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) {
computation = &force.getComputation(); computation = &force.getComputation();
forces.resize(context.getSystem().getNumParticles()); particles = force.getParticles();
numParticles = particles.size();
if (numParticles == 0)
numParticles = context.getSystem().getNumParticles();
else
positions.resize(numParticles);
forces.resize(numParticles);
usePeriodic = force.usesPeriodicBoundaryConditions(); usePeriodic = force.usesPeriodicBoundaryConditions();
} }
...@@ -3560,7 +3566,13 @@ double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includ ...@@ -3560,7 +3566,13 @@ double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includ
vector<Vec3>& posData = extractPositions(context); vector<Vec3>& posData = extractPositions(context);
vector<Vec3>& forceData = extractForces(context); vector<Vec3>& forceData = extractForces(context);
State::StateBuilder builder(context.getTime(), context.getStepCount()); State::StateBuilder builder(context.getTime(), context.getStepCount());
if (particles.size() == 0)
builder.setPositions(posData); builder.setPositions(posData);
else {
for (int i = 0; i < particles.size(); i++)
positions[i] = posData[particles[i]];
builder.setPositions(positions);
}
builder.setParameters(context.getParameters()); builder.setParameters(context.getParameters());
if (usePeriodic) { if (usePeriodic) {
Vec3 a, b, c; Vec3 a, b, c;
...@@ -3570,8 +3582,15 @@ double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includ ...@@ -3570,8 +3582,15 @@ double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includ
double energy; double energy;
State state = builder.getState(); State state = builder.getState();
computation->compute(state, energy, forces.data(), true); computation->compute(state, energy, forces.data(), true);
if (includeForces) if (includeForces) {
if (particles.size() == 0) {
for (int i = 0; i < forces.size(); i++) for (int i = 0; i < forces.size(); i++)
forceData[i] += forces[i]; forceData[i] += forces[i];
}
else {
for (int i = 0; i < forces.size(); i++)
forceData[particles[i]] += forces[i];
}
}
return energy; return energy;
} }
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2025 Stanford University and the Authors. * * Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
#include "openmm/internal/AssertionUtilities.h" #include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h" #include "openmm/Context.h"
#include "openmm/NonbondedForce.h"
#include "openmm/PythonForce.h" #include "openmm/PythonForce.h"
#include "openmm/Platform.h" #include "openmm/Platform.h"
#include "openmm/VerletIntegrator.h" #include "openmm/VerletIntegrator.h"
...@@ -38,7 +39,7 @@ ...@@ -38,7 +39,7 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
void testForce() { void testForce(bool subsetParticles) {
class Computation : public PythonForceComputation { class Computation : public PythonForceComputation {
void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const { void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const {
ASSERT_EQUAL(5.0, state.getParameters().at("a")); ASSERT_EQUAL(5.0, state.getParameters().at("a"));
...@@ -63,22 +64,31 @@ void testForce() { ...@@ -63,22 +64,31 @@ void testForce() {
} }
}; };
int numParticles = 5; int numParticles = 5;
int totalParticles = (subsetParticles ? numParticles+10 : numParticles);
System system; System system;
Vec3 a(2, 0, 0); Vec3 a(2, 0, 0);
Vec3 b(0.1, 2, 0); Vec3 b(0.1, 2, 0);
Vec3 c(0.1, 0.1, 2); Vec3 c(0.1, 0.1, 2);
system.setDefaultPeriodicBoxVectors(a, b, c); system.setDefaultPeriodicBoxVectors(a, b, c);
NonbondedForce* nonbonded = new NonbondedForce(); // To trigger reordering
nonbonded->setNonbondedMethod(NonbondedForce::PME);
system.addForce(nonbonded);
vector<Vec3> positions; vector<Vec3> positions;
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt); init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < totalParticles; i++) {
system.addParticle(1.0); system.addParticle(1.0);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt))); positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt)));
nonbonded->addParticle(0.0, 1.0, 0.0);
} }
map<string, double> params; map<string, double> params;
params["a"] = 5.0; params["a"] = 5.0;
params["b"] = 10.0; params["b"] = 10.0;
PythonForce* force = new PythonForce(new Computation(), params); vector<int> particles;
if (subsetParticles)
for (int i = 0; i < numParticles; i++)
particles.push_back(i+5);
PythonForce* force = new PythonForce(new Computation(), params, particles);
ASSERT(!force->usesPeriodicBoundaryConditions()); ASSERT(!force->usesPeriodicBoundaryConditions());
force->setUsesPeriodicBoundaryConditions(true); force->setUsesPeriodicBoundaryConditions(true);
ASSERT(force->usesPeriodicBoundaryConditions()); ASSERT(force->usesPeriodicBoundaryConditions());
...@@ -88,8 +98,17 @@ void testForce() { ...@@ -88,8 +98,17 @@ void testForce() {
context.setPositions(positions); context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces); State state = context.getState(State::Energy | State::Forces);
ASSERT_EQUAL_TOL(25.0, state.getPotentialEnergy(), 1e-6); ASSERT_EQUAL_TOL(25.0, state.getPotentialEnergy(), 1e-6);
if (subsetParticles) {
for (int i : particles)
ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6)
Vec3 zero;
for (int i = 0; i < 5; i++)
ASSERT_EQUAL_VEC(zero, state.getForces()[i], 1e-6);
}
else {
for (int i = 0; i < numParticles; i++) for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6) ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6)
}
// Check that force groups are handled correctly. // Check that force groups are handled correctly.
...@@ -102,7 +121,8 @@ void runPlatformTests(); ...@@ -102,7 +121,8 @@ void runPlatformTests();
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
initializeTests(argc, argv); initializeTests(argc, argv);
testForce(); testForce(false);
testForce(true);
runPlatformTests(); runPlatformTests();
} }
catch(const exception& e) { catch(const exception& e) {
......
...@@ -580,5 +580,6 @@ UNITS = { ...@@ -580,5 +580,6 @@ UNITS = {
("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()), ("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()),
("PythonForce", "getComputation") : (None, ()), ("PythonForce", "getComputation") : (None, ()),
("PythonForce", "getGlobalParameters") : (None, ()), ("PythonForce", "getGlobalParameters") : (None, ()),
("PythonForce", "getParticles") : (None, ()),
("PythonForce", "getPickledFunction") : (None, ()), ("PythonForce", "getPickledFunction") : (None, ()),
} }
...@@ -91,8 +91,8 @@ namespace OpenMM { ...@@ -91,8 +91,8 @@ namespace OpenMM {
/** /**
* Construct a new PythonForce. * Construct a new PythonForce.
*/ */
PythonForce* _createPythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) { PythonForce* _createPythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}, const std::vector<int>& particles={}) {
PythonForce* force = new PythonForce(new ComputationWrapper(computation), globalParameters); PythonForce* force = new PythonForce(new ComputationWrapper(computation), globalParameters, particles);
PyObject* pickle = PyImport_ImportModule("pickle"); PyObject* pickle = PyImport_ImportModule("pickle");
PyObject* dumps = PyUnicode_FromString("dumps"); PyObject* dumps = PyUnicode_FromString("dumps");
PyObject* result = PyObject_CallMethodOneArg(pickle, dumps, computation); PyObject* result = PyObject_CallMethodOneArg(pickle, dumps, computation);
...@@ -138,7 +138,7 @@ namespace OpenMM { ...@@ -138,7 +138,7 @@ namespace OpenMM {
} }
void serialize(const void* object, SerializationNode& node) const { void serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1); node.setIntProperty("version", 2);
const PythonForce& force = *reinterpret_cast<const PythonForce*>(object); const PythonForce& force = *reinterpret_cast<const PythonForce*>(object);
if (force.getPickledFunction().size() == 0) if (force.getPickledFunction().size() == 0)
throw OpenMMException("PythonForceProxy: Could not serialize PythonForce because its function could not be pickled."); throw OpenMMException("PythonForceProxy: Could not serialize PythonForce because its function could not be pickled.");
...@@ -148,11 +148,14 @@ namespace OpenMM { ...@@ -148,11 +148,14 @@ namespace OpenMM {
SerializationNode& globalParams = node.createChildNode("GlobalParameters"); SerializationNode& globalParams = node.createChildNode("GlobalParameters");
for (auto param : force.getGlobalParameters()) for (auto param : force.getGlobalParameters())
globalParams.createChildNode("Parameter").setStringProperty("name", param.first).setDoubleProperty("default", param.second); globalParams.createChildNode("Parameter").setStringProperty("name", param.first).setDoubleProperty("default", param.second);
SerializationNode& particlesNode = node.createChildNode("Particles");
for (int i : force.getParticles())
particlesNode.createChildNode("Particle").setIntProperty("index", i);
} }
void* deserialize(const SerializationNode& node) const { void* deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version"); int version = node.getIntProperty("version");
if (version != 1) if (version < 1 || version > 2)
throw OpenMMException("Unsupported version number"); throw OpenMMException("Unsupported version number");
std::vector<char> pickledFunction = hexDecode(node.getStringProperty("function")); std::vector<char> pickledFunction = hexDecode(node.getStringProperty("function"));
PyObject* pickle = PyImport_ImportModule("pickle"); PyObject* pickle = PyImport_ImportModule("pickle");
...@@ -164,7 +167,11 @@ namespace OpenMM { ...@@ -164,7 +167,11 @@ namespace OpenMM {
std::map<std::string, double> params; std::map<std::string, double> params;
for (auto& parameter : paramsNode.getChildren()) for (auto& parameter : paramsNode.getChildren())
params[parameter.getStringProperty("name")] = parameter.getDoubleProperty("default"); params[parameter.getStringProperty("name")] = parameter.getDoubleProperty("default");
PythonForce* force = _createPythonForce(function, params); std::vector<int> particles;
if (version > 1)
for (auto& particle : node.getChildNode("Particles").getChildren())
particles.push_back(particle.getIntProperty("index"));
PythonForce* force = _createPythonForce(function, params, particles);
if (node.hasProperty("forceGroup")) if (node.hasProperty("forceGroup"))
force->setForceGroup(node.getIntProperty("forceGroup", 0)); force->setForceGroup(node.getIntProperty("forceGroup", 0));
if (node.hasProperty("usesPeriodic")) if (node.hasProperty("usesPeriodic"))
...@@ -195,7 +202,7 @@ globalParameters : dict ...@@ -195,7 +202,7 @@ globalParameters : dict
Any global parameters the function depends on. Keys are the parameter names, and the Any global parameters the function depends on. Keys are the parameter names, and the
corresponding values are their default values. corresponding values are their default values.
" "
PythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) { PythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}, const std::vector<int>& particles={}) {
return _createPythonForce(computation, globalParameters); return _createPythonForce(computation, globalParameters, particles);
} }
} }
...@@ -38,6 +38,33 @@ class TestPythonForce(unittest.TestCase): ...@@ -38,6 +38,33 @@ class TestPythonForce(unittest.TestCase):
self.assertAlmostEqual(2.5*np.sum(positions*positions), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5) self.assertAlmostEqual(2.5*np.sum(positions*positions), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5)
self.assertTrue(np.allclose(-1.25*positions, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer))) self.assertTrue(np.allclose(-1.25*positions, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer)))
def testParticleSubset(self):
"""Test a PythonForce appled to a subset of particles."""
system = System()
for i in range(10):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
particles = [1,3,5,7,9]
force.setParticles(particles)
system.addForce(force)
positions = np.random.rand(10, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
state = context.getState(energy=True, forces=True)
filtered = np.zeros(positions.shape)
filtered[particles] = positions[particles]
self.assertAlmostEqual(2.5*np.sum(filtered*filtered), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5)
self.assertTrue(np.allclose(-1.25*filtered, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer)))
def testExceptions(self): def testExceptions(self):
"""Test that PythonForce handles exceptions correctly.""" """Test that PythonForce handles exceptions correctly."""
def compute2(state): def compute2(state):
...@@ -67,6 +94,7 @@ class TestPythonForce(unittest.TestCase): ...@@ -67,6 +94,7 @@ class TestPythonForce(unittest.TestCase):
"""Test that PythonForce can be serialized.""" """Test that PythonForce can be serialized."""
force1 = PythonForce(compute, {'k':2.5}) force1 = PythonForce(compute, {'k':2.5})
force1.setUsesPeriodicBoundaryConditions(True) force1.setUsesPeriodicBoundaryConditions(True)
force1.setParticles([1,3,5])
# Make a copy by serializing and the deserializing it. # Make a copy by serializing and the deserializing it.
...@@ -76,6 +104,7 @@ class TestPythonForce(unittest.TestCase): ...@@ -76,6 +104,7 @@ class TestPythonForce(unittest.TestCase):
self.assertEqual(XmlSerializer.serialize(force1), XmlSerializer.serialize(force2)) self.assertEqual(XmlSerializer.serialize(force1), XmlSerializer.serialize(force2))
self.assertEqual(dict(force2.getGlobalParameters()), {'k':2.5}) self.assertEqual(dict(force2.getGlobalParameters()), {'k':2.5})
self.assertEqual(force1.getParticles(), force2.getParticles())
self.assertTrue(force2.usesPeriodicBoundaryConditions()) self.assertTrue(force2.usesPeriodicBoundaryConditions())
# A locally defined function cannot be pickled. We should not be able to serialize a force # A locally defined function cannot be pickled. We should not be able to serialize a force
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment