Unverified Commit 9aae4bb3 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

PythonForce can be restricted to a subset of particles (#5246)

* PythonForce can be restricted to a subset of particles

* Fix exception with CUDA
parent 6717a85c
......@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -103,6 +103,15 @@ public:
* box vectors. The positions may also be wrapped into a different periodic box to keep them
* closer to the origin and improve accuracy.
*
* A PythonForce can optionally be applied to only a subset of the particles in a system. To do
* this, call setParticles() on it, providing the indices of the particles to apply it to. The
* computation function should then proceed as if those particles were the entire system.
* state.getPositions() will return a smaller array containing only the positions of those
* particles, and the array of forces should similarly contain only those particles. That is,
* forces[i] should be the force on the i'th particle passed to setParticles(). When applying
* forces to only a small fraction of the particles in a system, this can greatly improve
* performance.
*
* When using XmlSerializer to save a PythonForce, it uses the Python pickle module to save
* the computation function. If it cannot be pickled, you will not be able to serialize the
* PythonForce. Functions defined at the top level of a module can usually be pickled, but local
......@@ -124,9 +133,12 @@ public:
* @param computation an object defining how the forces and energy should be computed
* @param globalParameters any global parameters used by the force. Keys are the parameter
* names, and the corresponding values are their default values.
* @param particles the indices of the particles to use when computing the force. If
* this is empty (the default), all particles in the system will be used.
* @private
*/
explicit PythonForce(PythonForceComputation* computation, const std::map<std::string, double>& globalParameters);
explicit PythonForce(PythonForceComputation* computation, const std::map<std::string, double>& globalParameters,
const std::vector<int>& particles=std::vector<int>());
~PythonForce();
/**
* Get the PythonForceComputation that defines the computation.
......@@ -138,6 +150,18 @@ public:
* corresponding values are their default values.
*/
const std::map<std::string, double>& getGlobalParameters() const;
/**
* Get the indices of the particles to use when computing the force. If this
* is empty, all particles in the system will be used.
*/
const std::vector<int>& getParticles() const {
return particles;
}
/**
* Set the indices of the particles to use when computing the force. If this
* is empty, all particles in the system will be used.
*/
void setParticles(const std::vector<int>& particles);
/**
* Get the pickled representation of the computation function. If it cannot be pickled,
* this will be an empty vector.
......@@ -168,6 +192,7 @@ private:
PythonForceComputation* computation;
std::map<std::string, double> globalParameters;
bool usePeriodic;
std::vector<int> particles;
std::vector<char> pickled;
};
......
......@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -33,8 +33,8 @@
using namespace OpenMM;
using namespace std;
PythonForce::PythonForce(PythonForceComputation* computation, const map<string, double>& globalParameters) :
computation(computation), globalParameters(globalParameters), usePeriodic(false) {
PythonForce::PythonForce(PythonForceComputation* computation, const map<string, double>& globalParameters, const vector<int>& particles) :
computation(computation), globalParameters(globalParameters), usePeriodic(false), particles(particles) {
}
PythonForce::~PythonForce() {
......@@ -49,6 +49,10 @@ const map<string, double>& PythonForce::getGlobalParameters() const {
return globalParameters;
}
void PythonForce::setParticles(const std::vector<int>& particles) {
this->particles = particles;
}
bool PythonForce::usesPeriodicBoundaryConditions() const {
return usePeriodic;
}
......
......@@ -1550,14 +1550,18 @@ private:
class ExecuteTask;
class StartCalculationPreComputation;
class AddForcesPostComputation;
class ReorderListener;
void getPositions();
void sortParticles();
OpenMM::ContextImpl& contextImpl;
ComputeContext& cc;
const PythonForceComputation* computation;
ComputeArray forcesArray;
ComputeKernel addForcesKernel;
ComputeArray positionsArray, forcesArray, particlesArray, reorderedParticles;
ComputeKernel copyPositionsKernel, addForcesKernel;
std::vector<Vec3> positionsVec;
std::vector<double> forcesVec;
int forceGroupFlag;
std::vector<int> particles;
int numParticles, forceGroupFlag;
double energy;
bool usePeriodic, useWorkerThread;
};
......
......@@ -4816,23 +4816,58 @@ public:
CommonCalcPythonForceKernel& owner;
};
class CommonCalcPythonForceKernel::ReorderListener : public ComputeContext::ReorderListener {
public:
ReorderListener(CommonCalcPythonForceKernel& owner) : owner(owner) {
}
void execute() {
owner.sortParticles();
}
private:
CommonCalcPythonForceKernel& owner;
};
void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) {
ContextSelector selector(cc);
computation = &force.getComputation();
usePeriodic = force.usesPeriodicBoundaryConditions();
int numParticles = context.getSystem().getNumParticles();
particles = force.getParticles();
numParticles = particles.size();
if (numParticles == 0)
numParticles = context.getSystem().getNumParticles();
positionsVec.resize(numParticles);
forcesVec.resize(3*numParticles);
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
positionsArray.initialize(cc, 3*numParticles, elementSize, "positions");
forcesArray.initialize(cc, 3*numParticles, elementSize, "forces");
map<string, string> defines;
defines["NUM_ATOMS"] = cc.intToString(numParticles);
defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms());
ComputeProgram program = cc.compileProgram(CommonKernelSources::customCppForce, defines);
addForcesKernel = program->createKernel("addForces");
ComputeProgram program = cc.compileProgram(CommonKernelSources::pythonForce, defines);
if (particles.size() > 0) {
particlesArray.initialize<int>(cc, numParticles, "particles");
reorderedParticles.initialize<int>(cc, numParticles, "reorderedParticles");
particlesArray.upload(particles);
reorderedParticles.upload(particles);
cc.addReorderListener(new ReorderListener(*this));
copyPositionsKernel = program->createKernel("copyPositions");
copyPositionsKernel->addArg(cc.getPosq());
copyPositionsKernel->addArg(positionsArray);
copyPositionsKernel->addArg(reorderedParticles);
copyPositionsKernel->addArg(numParticles);
addForcesKernel = program->createKernel("addForcesSubset");
addForcesKernel->addArg(forcesArray);
addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray());
addForcesKernel->addArg(reorderedParticles);
addForcesKernel->addArg(numParticles);
}
else {
addForcesKernel = program->createKernel("addForcesAll");
addForcesKernel->addArg(forcesArray);
addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray());
}
forceGroupFlag = (1<<force.getForceGroup());
useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls())
......@@ -4858,11 +4893,63 @@ double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeFo
if (cc.getContextIndex() != 0)
return 0.0;
contextImpl.getPositions(positionsVec);
getPositions();
executeOnWorkerThread(includeForces);
return addForces(includeForces, includeEnergy, -1);
}
void CommonCalcPythonForceKernel::getPositions() {
// If the NonbondedUtilities uses periodic boundary conditions, the positions might have been
// wrapped to the periodic box. If this force also applies periodic boundary conditions, that's
// alright. Otherwise, we need to move them back.
bool fixPeriodic = usePeriodic || !cc.getNonbondedUtilities().getUsePeriodic();
if (particles.size() == 0) {
// The force applies to the whole system, so we can just use the standard getPositions().
contextImpl.getPositions(positionsVec, fixPeriodic);
}
else {
// Retrieve positions for the subset of particles the force is applied to.
ContextSelector selector(cc);
copyPositionsKernel->execute(numParticles);
if (cc.getUseDoublePrecision()) {
vector<double> pos(3*numParticles);
positionsArray.download(pos);
for (int i = 0; i < numParticles; i++)
positionsVec[i] = Vec3(pos[3*i], pos[3*i+1], pos[3*i+2]);
}
else {
vector<float> pos(3*numParticles);
positionsArray.download(pos);
for (int i = 0; i < numParticles; i++)
positionsVec[i] = Vec3((double) pos[3*i], (double) pos[3*i+1], (double) pos[3*i+2]);
}
if (fixPeriodic) {
Vec3 boxVectors[3];
cc.getPeriodicBoxVectors(boxVectors[0], boxVectors[1], boxVectors[2]);
for (int i = 0; i < numParticles; ++i) {
mm_int4 offset = cc.getPosCellOffsets()[particles[i]];
positionsVec[i] -= boxVectors[0]*offset.x-boxVectors[1]*offset.y-boxVectors[2]*offset.z;
}
}
}
}
void CommonCalcPythonForceKernel::sortParticles() {
// Update the list of particles to account for reordering.
const vector<int>& order = cc.getAtomIndex();
vector<int> inverseOrder(order.size());
for (int i = 0; i < cc.getNumAtoms(); i++)
inverseOrder[order[i]] = i;
vector<int> reordered(particles.size());
for (int i = 0; i < particles.size(); i++)
reordered[i] = inverseOrder[particles[i]];
reorderedParticles.upload(reordered);
}
void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool includeEnergy, int groups) {
if ((groups&forceGroupFlag) == 0)
return;
......@@ -4873,7 +4960,7 @@ void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool incl
}
void CommonCalcPythonForceKernel::executeOnWorkerThread(bool includeForces) {
contextImpl.getPositions(positionsVec, usePeriodic || !cc.getNonbondedUtilities().getUsePeriodic());
getPositions();
State::StateBuilder builder(contextImpl.getTime(), contextImpl.getStepCount());
builder.setPositions(positionsVec);
builder.setParameters(contextImpl.getParameters());
......
KERNEL void copyPositions(GLOBAL const real4* RESTRICT posq, GLOBAL real* RESTRICT positions, GLOBAL int* RESTRICT particles, int numParticles) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
real4 pos = posq[particles[i]];
positions[3*i] = pos.x;
positions[3*i+1] = pos.y;
positions[3*i+2] = pos.z;
}
}
KERNEL void addForcesAll(GLOBAL const real* RESTRICT forces, GLOBAL mm_long* RESTRICT forceBuffers, GLOBAL int* RESTRICT atomIndex) {
for (int atom = GLOBAL_ID; atom < NUM_ATOMS; atom += GLOBAL_SIZE) {
int index = atomIndex[atom];
forceBuffers[atom] += (mm_long) (forces[3*index]*0x100000000);
forceBuffers[atom+PADDED_NUM_ATOMS] += (mm_long) (forces[3*index+1]*0x100000000);
forceBuffers[atom+2*PADDED_NUM_ATOMS] += (mm_long) (forces[3*index+2]*0x100000000);
}
}
KERNEL void addForcesSubset(GLOBAL const real* RESTRICT forces, GLOBAL mm_long* RESTRICT forceBuffers, GLOBAL int* RESTRICT atomIndex, GLOBAL int* RESTRICT particles, int numParticles) {
for (int i = GLOBAL_ID; i < numParticles; i += GLOBAL_SIZE) {
int index = particles[i];
forceBuffers[index] += (mm_long) (forces[3*i]*0x100000000);
forceBuffers[index+PADDED_NUM_ATOMS] += (mm_long) (forces[3*i+1]*0x100000000);
forceBuffers[index+2*PADDED_NUM_ATOMS] += (mm_long) (forces[3*i+2]*0x100000000);
}
}
......@@ -2035,7 +2035,9 @@ public:
double execute(ContextImpl& context, bool includeForces, bool includeEnergy);
private:
const PythonForceComputation* computation;
std::vector<Vec3> forces;
std::vector<Vec3> positions, forces;
std::vector<int> particles;
int numParticles;
bool usePeriodic;
};
......
......@@ -3552,7 +3552,13 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc
void ReferenceCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) {
computation = &force.getComputation();
forces.resize(context.getSystem().getNumParticles());
particles = force.getParticles();
numParticles = particles.size();
if (numParticles == 0)
numParticles = context.getSystem().getNumParticles();
else
positions.resize(numParticles);
forces.resize(numParticles);
usePeriodic = force.usesPeriodicBoundaryConditions();
}
......@@ -3560,7 +3566,13 @@ double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includ
vector<Vec3>& posData = extractPositions(context);
vector<Vec3>& forceData = extractForces(context);
State::StateBuilder builder(context.getTime(), context.getStepCount());
if (particles.size() == 0)
builder.setPositions(posData);
else {
for (int i = 0; i < particles.size(); i++)
positions[i] = posData[particles[i]];
builder.setPositions(positions);
}
builder.setParameters(context.getParameters());
if (usePeriodic) {
Vec3 a, b, c;
......@@ -3570,8 +3582,15 @@ double ReferenceCalcPythonForceKernel::execute(ContextImpl& context, bool includ
double energy;
State state = builder.getState();
computation->compute(state, energy, forces.data(), true);
if (includeForces)
if (includeForces) {
if (particles.size() == 0) {
for (int i = 0; i < forces.size(); i++)
forceData[i] += forces[i];
}
else {
for (int i = 0; i < forces.size(); i++)
forceData[particles[i]] += forces[i];
}
}
return energy;
}
......@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. *
* *
* Portions copyright (c) 2025 Stanford University and the Authors. *
* Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -29,6 +29,7 @@
#include "openmm/internal/AssertionUtilities.h"
#include "openmm/Context.h"
#include "openmm/NonbondedForce.h"
#include "openmm/PythonForce.h"
#include "openmm/Platform.h"
#include "openmm/VerletIntegrator.h"
......@@ -38,7 +39,7 @@
using namespace OpenMM;
using namespace std;
void testForce() {
void testForce(bool subsetParticles) {
class Computation : public PythonForceComputation {
void compute(const State& state, double& energy, void* forces, bool forcesAreDouble) const {
ASSERT_EQUAL(5.0, state.getParameters().at("a"));
......@@ -63,22 +64,31 @@ void testForce() {
}
};
int numParticles = 5;
int totalParticles = (subsetParticles ? numParticles+10 : numParticles);
System system;
Vec3 a(2, 0, 0);
Vec3 b(0.1, 2, 0);
Vec3 c(0.1, 0.1, 2);
system.setDefaultPeriodicBoxVectors(a, b, c);
NonbondedForce* nonbonded = new NonbondedForce(); // To trigger reordering
nonbonded->setNonbondedMethod(NonbondedForce::PME);
system.addForce(nonbonded);
vector<Vec3> positions;
OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt);
for (int i = 0; i < numParticles; i++) {
for (int i = 0; i < totalParticles; i++) {
system.addParticle(1.0);
positions.push_back(Vec3(genrand_real2(sfmt), genrand_real2(sfmt), genrand_real2(sfmt)));
nonbonded->addParticle(0.0, 1.0, 0.0);
}
map<string, double> params;
params["a"] = 5.0;
params["b"] = 10.0;
PythonForce* force = new PythonForce(new Computation(), params);
vector<int> particles;
if (subsetParticles)
for (int i = 0; i < numParticles; i++)
particles.push_back(i+5);
PythonForce* force = new PythonForce(new Computation(), params, particles);
ASSERT(!force->usesPeriodicBoundaryConditions());
force->setUsesPeriodicBoundaryConditions(true);
ASSERT(force->usesPeriodicBoundaryConditions());
......@@ -88,8 +98,17 @@ void testForce() {
context.setPositions(positions);
State state = context.getState(State::Energy | State::Forces);
ASSERT_EQUAL_TOL(25.0, state.getPotentialEnergy(), 1e-6);
if (subsetParticles) {
for (int i : particles)
ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6)
Vec3 zero;
for (int i = 0; i < 5; i++)
ASSERT_EQUAL_VEC(zero, state.getForces()[i], 1e-6);
}
else {
for (int i = 0; i < numParticles; i++)
ASSERT_EQUAL_VEC(2*positions[i], state.getForces()[i], 1e-6)
}
// Check that force groups are handled correctly.
......@@ -102,7 +121,8 @@ void runPlatformTests();
int main(int argc, char* argv[]) {
try {
initializeTests(argc, argv);
testForce();
testForce(false);
testForce(true);
runPlatformTests();
}
catch(const exception& e) {
......
......@@ -580,5 +580,6 @@ UNITS = {
("FixedDisplacement", "getFixedDisplacement0") : ("unit.nanometer", ()),
("PythonForce", "getComputation") : (None, ()),
("PythonForce", "getGlobalParameters") : (None, ()),
("PythonForce", "getParticles") : (None, ()),
("PythonForce", "getPickledFunction") : (None, ()),
}
......@@ -91,8 +91,8 @@ namespace OpenMM {
/**
* Construct a new PythonForce.
*/
PythonForce* _createPythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) {
PythonForce* force = new PythonForce(new ComputationWrapper(computation), globalParameters);
PythonForce* _createPythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}, const std::vector<int>& particles={}) {
PythonForce* force = new PythonForce(new ComputationWrapper(computation), globalParameters, particles);
PyObject* pickle = PyImport_ImportModule("pickle");
PyObject* dumps = PyUnicode_FromString("dumps");
PyObject* result = PyObject_CallMethodOneArg(pickle, dumps, computation);
......@@ -138,7 +138,7 @@ namespace OpenMM {
}
void serialize(const void* object, SerializationNode& node) const {
node.setIntProperty("version", 1);
node.setIntProperty("version", 2);
const PythonForce& force = *reinterpret_cast<const PythonForce*>(object);
if (force.getPickledFunction().size() == 0)
throw OpenMMException("PythonForceProxy: Could not serialize PythonForce because its function could not be pickled.");
......@@ -148,11 +148,14 @@ namespace OpenMM {
SerializationNode& globalParams = node.createChildNode("GlobalParameters");
for (auto param : force.getGlobalParameters())
globalParams.createChildNode("Parameter").setStringProperty("name", param.first).setDoubleProperty("default", param.second);
SerializationNode& particlesNode = node.createChildNode("Particles");
for (int i : force.getParticles())
particlesNode.createChildNode("Particle").setIntProperty("index", i);
}
void* deserialize(const SerializationNode& node) const {
int version = node.getIntProperty("version");
if (version != 1)
if (version < 1 || version > 2)
throw OpenMMException("Unsupported version number");
std::vector<char> pickledFunction = hexDecode(node.getStringProperty("function"));
PyObject* pickle = PyImport_ImportModule("pickle");
......@@ -164,7 +167,11 @@ namespace OpenMM {
std::map<std::string, double> params;
for (auto& parameter : paramsNode.getChildren())
params[parameter.getStringProperty("name")] = parameter.getDoubleProperty("default");
PythonForce* force = _createPythonForce(function, params);
std::vector<int> particles;
if (version > 1)
for (auto& particle : node.getChildNode("Particles").getChildren())
particles.push_back(particle.getIntProperty("index"));
PythonForce* force = _createPythonForce(function, params, particles);
if (node.hasProperty("forceGroup"))
force->setForceGroup(node.getIntProperty("forceGroup", 0));
if (node.hasProperty("usesPeriodic"))
......@@ -195,7 +202,7 @@ globalParameters : dict
Any global parameters the function depends on. Keys are the parameter names, and the
corresponding values are their default values.
"
PythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}) {
return _createPythonForce(computation, globalParameters);
PythonForce(PyObject* computation, const std::map<std::string, double>& globalParameters={}, const std::vector<int>& particles={}) {
return _createPythonForce(computation, globalParameters, particles);
}
}
......@@ -38,6 +38,33 @@ class TestPythonForce(unittest.TestCase):
self.assertAlmostEqual(2.5*np.sum(positions*positions), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5)
self.assertTrue(np.allclose(-1.25*positions, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer)))
def testParticleSubset(self):
"""Test a PythonForce appled to a subset of particles."""
system = System()
for i in range(10):
system.addParticle(1.0)
force = PythonForce(compute, {'k':2.5})
particles = [1,3,5,7,9]
force.setParticles(particles)
system.addForce(force)
positions = np.random.rand(10, 3)
for i in range(Platform.getNumPlatforms()):
integrator = VerletIntegrator(0.001)
try:
context = Context(system, integrator, Platform.getPlatform(i))
except OpenMMException:
if i == 0:
raise
else:
# This happens on CI when no GPU is available.
continue
context.setPositions(positions)
state = context.getState(energy=True, forces=True)
filtered = np.zeros(positions.shape)
filtered[particles] = positions[particles]
self.assertAlmostEqual(2.5*np.sum(filtered*filtered), state.getPotentialEnergy().value_in_unit(kilojoules_per_mole), places=5)
self.assertTrue(np.allclose(-1.25*filtered, state.getForces(asNumpy=True).value_in_unit(kilojoules_per_mole/nanometer)))
def testExceptions(self):
"""Test that PythonForce handles exceptions correctly."""
def compute2(state):
......@@ -67,6 +94,7 @@ class TestPythonForce(unittest.TestCase):
"""Test that PythonForce can be serialized."""
force1 = PythonForce(compute, {'k':2.5})
force1.setUsesPeriodicBoundaryConditions(True)
force1.setParticles([1,3,5])
# Make a copy by serializing and the deserializing it.
......@@ -76,6 +104,7 @@ class TestPythonForce(unittest.TestCase):
self.assertEqual(XmlSerializer.serialize(force1), XmlSerializer.serialize(force2))
self.assertEqual(dict(force2.getGlobalParameters()), {'k':2.5})
self.assertEqual(force1.getParticles(), force2.getParticles())
self.assertTrue(force2.usesPeriodicBoundaryConditions())
# A locally defined function cannot be pickled. We should not be able to serialize a force
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment