Unverified Commit 5d8d5874 authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Avoid multiple forces running on the worker thread at once (#5243)

parent 1c528ca8
...@@ -1994,10 +1994,10 @@ public: ...@@ -1994,10 +1994,10 @@ public:
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param context the ContextImpl this kernel will be applied to
* @param force the CustomCPPForceImpl this kernel will be used for * @param force the CustomCPPForceImpl this kernel will be used for
*/ */
virtual void initialize(const System& system, CustomCPPForceImpl& force) = 0; virtual void initialize(const ContextImpl& context, CustomCPPForceImpl& force) = 0;
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
...@@ -2022,10 +2022,10 @@ public: ...@@ -2022,10 +2022,10 @@ public:
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param context the ContextImpl this kernel will be applied to
* @param force the PythonForce this kernel will be used for * @param force the PythonForce this kernel will be used for
*/ */
virtual void initialize(const System& system, const PythonForce& force) = 0; virtual void initialize(const ContextImpl& context, const PythonForce& force) = 0;
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2021 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -40,7 +40,7 @@ CustomCPPForceImpl::CustomCPPForceImpl(const Force& owner) { ...@@ -40,7 +40,7 @@ CustomCPPForceImpl::CustomCPPForceImpl(const Force& owner) {
void CustomCPPForceImpl::initialize(ContextImpl& context) { void CustomCPPForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcCustomCPPForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcCustomCPPForceKernel::Name(), context);
kernel.getAs<CalcCustomCPPForceKernel>().initialize(context.getSystem(), *this); kernel.getAs<CalcCustomCPPForceKernel>().initialize(context, *this);
} }
double CustomCPPForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomCPPForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2025 Stanford University and the Authors. * * Portions copyright (c) 2025-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -47,7 +47,7 @@ PythonForceImpl::~PythonForceImpl() { ...@@ -47,7 +47,7 @@ PythonForceImpl::~PythonForceImpl() {
void PythonForceImpl::initialize(ContextImpl& context) { void PythonForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcPythonForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcPythonForceKernel::Name(), context);
kernel.getAs<CalcPythonForceKernel>().initialize(context.getSystem(), owner); kernel.getAs<CalcPythonForceKernel>().initialize(context, owner);
} }
double PythonForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double PythonForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -1469,10 +1469,10 @@ public: ...@@ -1469,10 +1469,10 @@ public:
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param context the ContextImpl this kernel will be applied to
* @param force the CustomCPPForceImpl this kernel will be used for * @param force the CustomCPPForceImpl this kernel will be used for
*/ */
void initialize(const System& system, CustomCPPForceImpl& force); void initialize(const ContextImpl& context, CustomCPPForceImpl& force);
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
...@@ -1507,6 +1507,7 @@ private: ...@@ -1507,6 +1507,7 @@ private:
std::vector<float> floatForces; std::vector<float> floatForces;
int forceGroupFlag; int forceGroupFlag;
double energy; double energy;
bool useWorkerThread;
}; };
/** /**
...@@ -1520,10 +1521,10 @@ public: ...@@ -1520,10 +1521,10 @@ public:
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param context the ContextImpl this kernel will be applied to
* @param force the PythonForce this kernel will be used for * @param force the PythonForce this kernel will be used for
*/ */
void initialize(const System& system, const PythonForce& force); void initialize(const ContextImpl& context, const PythonForce& force);
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
...@@ -1558,7 +1559,7 @@ private: ...@@ -1558,7 +1559,7 @@ private:
std::vector<double> forcesVec; std::vector<double> forcesVec;
int forceGroupFlag; int forceGroupFlag;
double energy; double energy;
bool usePeriodic; bool usePeriodic, useWorkerThread;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
* This is part of the OpenMM molecular simulation toolkit. * * This is part of the OpenMM molecular simulation toolkit. *
* See https://openmm.org/development. * * See https://openmm.org/development. *
* * * *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Portions copyright (c) 2008-2026 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/internal/CustomCompoundBondForceImpl.h" #include "openmm/internal/CustomCompoundBondForceImpl.h"
#include "openmm/internal/DPDIntegratorUtilities.h" #include "openmm/internal/DPDIntegratorUtilities.h"
#include "openmm/internal/OSRngSeed.h" #include "openmm/internal/OSRngSeed.h"
#include "openmm/internal/PythonForceImpl.h"
#include "openmm/internal/ThreadPool.h" #include "openmm/internal/ThreadPool.h"
#include "openmm/internal/timer.h" #include "openmm/internal/timer.h"
#include "CommonKernelSources.h" #include "CommonKernelSources.h"
...@@ -4688,10 +4689,10 @@ public: ...@@ -4688,10 +4689,10 @@ public:
CommonCalcCustomCPPForceKernel& owner; CommonCalcCustomCPPForceKernel& owner;
}; };
void CommonCalcCustomCPPForceKernel::initialize(const System& system, CustomCPPForceImpl& force) { void CommonCalcCustomCPPForceKernel::initialize(const ContextImpl& context, CustomCPPForceImpl& force) {
ContextSelector selector(cc); ContextSelector selector(cc);
this->force = &force; this->force = &force;
int numParticles = system.getNumParticles(); int numParticles = context.getSystem().getNumParticles();
forcesVec.resize(numParticles); forcesVec.resize(numParticles);
positionsVec.resize(numParticles); positionsVec.resize(numParticles);
floatForces.resize(3*numParticles); floatForces.resize(3*numParticles);
...@@ -4706,14 +4707,18 @@ void CommonCalcCustomCPPForceKernel::initialize(const System& system, CustomCPPF ...@@ -4706,14 +4707,18 @@ void CommonCalcCustomCPPForceKernel::initialize(const System& system, CustomCPPF
addForcesKernel->addArg(cc.getLongForceBuffer()); addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray()); addForcesKernel->addArg(cc.getAtomIndexArray());
forceGroupFlag = (1<<force.getOwner().getForceGroup()); forceGroupFlag = (1<<force.getOwner().getForceGroup());
if (cc.getNumContexts() == 1) { useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls())
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL)
useWorkerThread = false;
if (useWorkerThread) {
cc.addPreComputation(new StartCalculationPreComputation(*this)); cc.addPreComputation(new StartCalculationPreComputation(*this));
cc.addPostComputation(new AddForcesPostComputation(*this)); cc.addPostComputation(new AddForcesPostComputation(*this));
} }
} }
double CommonCalcCustomCPPForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcCustomCPPForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
if (cc.getNumContexts() == 1) { if (useWorkerThread) {
// This method does nothing. The actual calculation is started by the pre-computation, continued on // This method does nothing. The actual calculation is started by the pre-computation, continued on
// the worker thread, and finished by the post-computation. // the worker thread, and finished by the post-computation.
...@@ -4765,7 +4770,7 @@ double CommonCalcCustomCPPForceKernel::addForces(bool includeForces, bool includ ...@@ -4765,7 +4770,7 @@ double CommonCalcCustomCPPForceKernel::addForces(bool includeForces, bool includ
// Wait until executeOnWorkerThread() is finished. // Wait until executeOnWorkerThread() is finished.
if (cc.getNumContexts() == 1) if (useWorkerThread)
cc.getWorkThread().flush(); cc.getWorkThread().flush();
// Add in the forces. // Add in the forces.
...@@ -4811,11 +4816,11 @@ public: ...@@ -4811,11 +4816,11 @@ public:
CommonCalcPythonForceKernel& owner; CommonCalcPythonForceKernel& owner;
}; };
void CommonCalcPythonForceKernel::initialize(const System& system, const PythonForce& force) { void CommonCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) {
ContextSelector selector(cc); ContextSelector selector(cc);
computation = &force.getComputation(); computation = &force.getComputation();
usePeriodic = force.usesPeriodicBoundaryConditions(); usePeriodic = force.usesPeriodicBoundaryConditions();
int numParticles = system.getNumParticles(); int numParticles = context.getSystem().getNumParticles();
positionsVec.resize(numParticles); positionsVec.resize(numParticles);
forcesVec.resize(3*numParticles); forcesVec.resize(3*numParticles);
int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float)); int elementSize = (cc.getUseDoublePrecision() ? sizeof(double) : sizeof(float));
...@@ -4829,14 +4834,18 @@ void CommonCalcPythonForceKernel::initialize(const System& system, const PythonF ...@@ -4829,14 +4834,18 @@ void CommonCalcPythonForceKernel::initialize(const System& system, const PythonF
addForcesKernel->addArg(cc.getLongForceBuffer()); addForcesKernel->addArg(cc.getLongForceBuffer());
addForcesKernel->addArg(cc.getAtomIndexArray()); addForcesKernel->addArg(cc.getAtomIndexArray());
forceGroupFlag = (1<<force.getForceGroup()); forceGroupFlag = (1<<force.getForceGroup());
if (cc.getNumContexts() == 1) { useWorkerThread = (cc.getNumContexts() == 1);
for (const ForceImpl* impl : context.getForceImpls())
if (dynamic_cast<const CustomCPPForceImpl*>(impl) != NULL || dynamic_cast<const PythonForceImpl*>(impl) != NULL)
useWorkerThread = false;
if (useWorkerThread) {
cc.addPreComputation(new StartCalculationPreComputation(*this)); cc.addPreComputation(new StartCalculationPreComputation(*this));
cc.addPostComputation(new AddForcesPostComputation(*this)); cc.addPostComputation(new AddForcesPostComputation(*this));
} }
} }
double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
if (cc.getNumContexts() == 1) { if (useWorkerThread) {
// This method does nothing. The actual calculation is started by the pre-computation, continued on // This method does nothing. The actual calculation is started by the pre-computation, continued on
// the worker thread, and finished by the post-computation. // the worker thread, and finished by the post-computation.
...@@ -4857,9 +4866,9 @@ double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeFo ...@@ -4857,9 +4866,9 @@ double CommonCalcPythonForceKernel::execute(ContextImpl& context, bool includeFo
void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool includeEnergy, int groups) { void CommonCalcPythonForceKernel::beginComputation(bool includeForces, bool includeEnergy, int groups) {
if ((groups&forceGroupFlag) == 0) if ((groups&forceGroupFlag) == 0)
return; return;
// The actual force computation will be done on a different thread. // The actual force computation will be done on a different thread.
cc.getWorkThread().addTask(new ExecuteTask(*this, includeForces)); cc.getWorkThread().addTask(new ExecuteTask(*this, includeForces));
} }
...@@ -4886,8 +4895,8 @@ double CommonCalcPythonForceKernel::addForces(bool includeForces, bool includeEn ...@@ -4886,8 +4895,8 @@ double CommonCalcPythonForceKernel::addForces(bool includeForces, bool includeEn
return 0; return 0;
// Wait until executeOnWorkerThread() is finished. // Wait until executeOnWorkerThread() is finished.
if (cc.getNumContexts() == 1) if (useWorkerThread)
cc.getWorkThread().flush(); cc.getWorkThread().flush();
// Add in the forces. // Add in the forces.
...@@ -4896,7 +4905,7 @@ double CommonCalcPythonForceKernel::addForces(bool includeForces, bool includeEn ...@@ -4896,7 +4905,7 @@ double CommonCalcPythonForceKernel::addForces(bool includeForces, bool includeEn
ContextSelector selector(cc); ContextSelector selector(cc);
addForcesKernel->execute(cc.getNumAtoms()); addForcesKernel->execute(cc.getNumAtoms());
} }
// Return the energy. // Return the energy.
return energy; return energy;
......
...@@ -1992,10 +1992,10 @@ public: ...@@ -1992,10 +1992,10 @@ public:
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param context the ContextImpl this kernel will be applied to
* @param force the CustomCPPForceImpl this kernel will be used for * @param force the CustomCPPForceImpl this kernel will be used for
*/ */
void initialize(const System& system, CustomCPPForceImpl& force); void initialize(const ContextImpl& context, CustomCPPForceImpl& force);
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
...@@ -2020,10 +2020,10 @@ public: ...@@ -2020,10 +2020,10 @@ public:
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
* @param system the System this kernel will be applied to * @param context the ContextImpl this kernel will be applied to
* @param force the PythonForce this kernel will be used for * @param force the PythonForce this kernel will be used for
*/ */
void initialize(const System& system, const PythonForce& force); void initialize(const ContextImpl& context, const PythonForce& force);
/** /**
* Execute the kernel to calculate the forces and/or energy. * Execute the kernel to calculate the forces and/or energy.
* *
......
...@@ -3535,9 +3535,9 @@ void ReferenceCalcATMForceKernel::copyParametersToContext(ContextImpl& context, ...@@ -3535,9 +3535,9 @@ void ReferenceCalcATMForceKernel::copyParametersToContext(ContextImpl& context,
loadParams(numParticles, force); loadParams(numParticles, force);
} }
void ReferenceCalcCustomCPPForceKernel::initialize(const System& system, CustomCPPForceImpl& force) { void ReferenceCalcCustomCPPForceKernel::initialize(const ContextImpl& context, CustomCPPForceImpl& force) {
this->force = &force; this->force = &force;
forces.resize(system.getNumParticles()); forces.resize(context.getSystem().getNumParticles());
} }
double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -3550,9 +3550,9 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc ...@@ -3550,9 +3550,9 @@ double ReferenceCalcCustomCPPForceKernel::execute(ContextImpl& context, bool inc
return energy; return energy;
} }
void ReferenceCalcPythonForceKernel::initialize(const System& system, const PythonForce& force) { void ReferenceCalcPythonForceKernel::initialize(const ContextImpl& context, const PythonForce& force) {
computation = &force.getComputation(); computation = &force.getComputation();
forces.resize(system.getNumParticles()); forces.resize(context.getSystem().getNumParticles());
usePeriodic = force.usesPeriodicBoundaryConditions(); usePeriodic = force.usesPeriodicBoundaryConditions();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment