Unverified Commit 6817921b authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Unified threading for Reference and CPU platforms (#4987)

parent 17f085d5
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. * * Portions copyright (c) 2013-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -89,7 +89,7 @@ private: ...@@ -89,7 +89,7 @@ private:
class CpuPlatform::PlatformData { class CpuPlatform::PlatformData {
public: public:
PlatformData(int numParticles, int numThreads, bool deterministicForces); PlatformData(int numParticles, ThreadPool& threads, bool deterministicForces);
~PlatformData(); ~PlatformData();
/** /**
* Request that a neighbor list be built and maintained. * Request that a neighbor list be built and maintained.
...@@ -107,7 +107,7 @@ public: ...@@ -107,7 +107,7 @@ public:
int requestPosqIndex(); int requestPosqIndex();
AlignedArray<float> posq; AlignedArray<float> posq;
std::vector<AlignedArray<float> > threadForce; std::vector<AlignedArray<float> > threadForce;
ThreadPool threads; ThreadPool& threads;
bool isPeriodic; bool isPeriodic;
CpuRandom random; CpuRandom random;
std::map<std::string, std::string> propertyValues; std::map<std::string, std::string> propertyValues;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2013-2024 Stanford University and the Authors. * * Portions copyright (c) 2013-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -112,18 +112,19 @@ bool CpuPlatform::isProcessorSupported() { ...@@ -112,18 +112,19 @@ bool CpuPlatform::isProcessorSupported() {
} }
void CpuPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void CpuPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
ReferencePlatform::contextCreated(context, properties);
const string& threadsPropValue = (properties.find(CpuThreads()) == properties.end() ? const string& threadsPropValue = (properties.find(CpuThreads()) == properties.end() ?
getPropertyDefaultValue(CpuThreads()) : properties.find(CpuThreads())->second); getPropertyDefaultValue(CpuThreads()) : properties.find(CpuThreads())->second);
map<string, string> refProperties = properties;
refProperties["Threads"] = threadsPropValue;
ReferencePlatform::contextCreated(context, refProperties);
string deterministicForcesValue = (properties.find(CpuDeterministicForces()) == properties.end() ? string deterministicForcesValue = (properties.find(CpuDeterministicForces()) == properties.end() ?
getPropertyDefaultValue(CpuDeterministicForces()) : properties.find(CpuDeterministicForces())->second); getPropertyDefaultValue(CpuDeterministicForces()) : properties.find(CpuDeterministicForces())->second);
int numThreads;
stringstream(threadsPropValue) >> numThreads;
transform(deterministicForcesValue.begin(), deterministicForcesValue.end(), deterministicForcesValue.begin(), ::tolower); transform(deterministicForcesValue.begin(), deterministicForcesValue.end(), deterministicForcesValue.begin(), ::tolower);
bool deterministicForces = (deterministicForcesValue == "true"); bool deterministicForces = (deterministicForcesValue == "true");
PlatformData* data = new PlatformData(context.getSystem().getNumParticles(), numThreads, deterministicForces); ReferencePlatform::PlatformData* refData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
PlatformData* data = new PlatformData(context.getSystem().getNumParticles(), refData->threads, deterministicForces);
contextData[&context] = data; contextData[&context] = data;
ReferenceConstraints& constraints = *(ReferenceConstraints*) reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData())->constraints; ReferenceConstraints& constraints = *(ReferenceConstraints*) refData->constraints;
if (constraints.settle != NULL) { if (constraints.settle != NULL) {
CpuSETTLE* parallelSettle = new CpuSETTLE(context.getSystem(), *(ReferenceSETTLEAlgorithm*) constraints.settle, data->threads); CpuSETTLE* parallelSettle = new CpuSETTLE(context.getSystem(), *(ReferenceSETTLEAlgorithm*) constraints.settle, data->threads);
delete constraints.settle; delete constraints.settle;
...@@ -147,10 +148,10 @@ const CpuPlatform::PlatformData& CpuPlatform::getPlatformData(const ContextImpl& ...@@ -147,10 +148,10 @@ const CpuPlatform::PlatformData& CpuPlatform::getPlatformData(const ContextImpl&
return *contextData[&context]; return *contextData[&context];
} }
CpuPlatform::PlatformData::PlatformData(int numParticles, int numThreads, bool deterministicForces) : posq(4*numParticles), threads(numThreads), CpuPlatform::PlatformData::PlatformData(int numParticles, ThreadPool& threads, bool deterministicForces) : posq(4*numParticles), threads(threads),
deterministicForces(deterministicForces), numParticles(numParticles), neighborList(NULL), cutoff(0.0), paddedCutoff(0.0), anyExclusions(false), deterministicForces(deterministicForces), numParticles(numParticles), neighborList(NULL), cutoff(0.0), paddedCutoff(0.0), anyExclusions(false),
currentPosqIndex(-1), nextPosqIndex(0) { currentPosqIndex(-1), nextPosqIndex(0) {
numThreads = threads.getNumThreads(); int numThreads = threads.getNumThreads();
threadForce.resize(numThreads); threadForce.resize(numThreads);
for (int i = 0; i < numThreads; i++) for (int i = 0; i < numThreads; i++)
threadForce[i].resize(4*numParticles); threadForce[i].resize(4*numParticles);
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2023 Stanford University and the Authors. * * Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "openmm/Platform.h" #include "openmm/Platform.h"
#include "openmm/System.h" #include "openmm/System.h"
#include "openmm/internal/ThreadPool.h"
#include "openmm/internal/windowsExport.h" #include "openmm/internal/windowsExport.h"
#include "ReferenceConstraints.h" #include "ReferenceConstraints.h"
#include "ReferenceVirtualSites.h" #include "ReferenceVirtualSites.h"
...@@ -62,11 +63,12 @@ public: ...@@ -62,11 +63,12 @@ public:
class OPENMM_EXPORT ReferencePlatform::PlatformData { class OPENMM_EXPORT ReferencePlatform::PlatformData {
public: public:
PlatformData(const System& system); PlatformData(const System& system, int numThreads);
~PlatformData(); ~PlatformData();
int numParticles; int numParticles;
long long stepCount; long long stepCount;
double time; double time;
ThreadPool threads;
std::vector<Vec3>* positions; std::vector<Vec3>* positions;
std::vector<Vec3>* velocities; std::vector<Vec3>* velocities;
std::vector<Vec3>* forces; std::vector<Vec3>* forces;
......
...@@ -131,6 +131,11 @@ static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& conte ...@@ -131,6 +131,11 @@ static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& conte
return *data->energyParameterDerivatives; return *data->energyParameterDerivatives;
} }
static ThreadPool& extractThreadPool(ContextImpl& context) {
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return data->threads;
}
/** /**
* Make sure an expression doesn't use any undefined variables. * Make sure an expression doesn't use any undefined variables.
*/ */
...@@ -1339,13 +1344,13 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo ...@@ -1339,13 +1344,13 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo
// Add in the long range correction. // Add in the long range correction.
if (!hasInitializedLongRangeCorrection) { if (!hasInitializedLongRangeCorrection) {
ThreadPool threads; ThreadPool& threads = extractThreadPool(context);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, threads.getNumThreads()); longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads); CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
} }
else if (globalParamsChanged && forceCopy != NULL) { else if (globalParamsChanged && forceCopy != NULL) {
ThreadPool threads; ThreadPool& threads = extractThreadPool(context);
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads); CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
} }
double volume = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2]; double volume = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2];
...@@ -1372,7 +1377,7 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp ...@@ -1372,7 +1377,7 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
// If necessary, recompute the long range correction. // If necessary, recompute the long range correction.
if (forceCopy != NULL) { if (forceCopy != NULL) {
ThreadPool threads; ThreadPool& threads = extractThreadPool(context);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, threads.getNumThreads()); longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads); CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true; hasInitializedLongRangeCorrection = true;
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. * * Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "openmm/internal/ContextImpl.h" #include "openmm/internal/ContextImpl.h"
#include "SimTKOpenMMRealType.h" #include "SimTKOpenMMRealType.h"
#include "openmm/Vec3.h" #include "openmm/Vec3.h"
#include <sstream>
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
...@@ -89,7 +90,10 @@ bool ReferencePlatform::supportsDoublePrecision() const { ...@@ -89,7 +90,10 @@ bool ReferencePlatform::supportsDoublePrecision() const {
} }
void ReferencePlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const { void ReferencePlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
context.setPlatformData(new PlatformData(context.getSystem())); int numThreads = 0;
if (properties.find("Threads") != properties.end())
stringstream(properties.at("Threads")) >> numThreads;
context.setPlatformData(new PlatformData(context.getSystem(), numThreads));
} }
void ReferencePlatform::contextDestroyed(ContextImpl& context) const { void ReferencePlatform::contextDestroyed(ContextImpl& context) const {
...@@ -97,7 +101,8 @@ void ReferencePlatform::contextDestroyed(ContextImpl& context) const { ...@@ -97,7 +101,8 @@ void ReferencePlatform::contextDestroyed(ContextImpl& context) const {
delete data; delete data;
} }
ReferencePlatform::PlatformData::PlatformData(const System& system) : time(0.0), stepCount(0), numParticles(system.getNumParticles()) { ReferencePlatform::PlatformData::PlatformData(const System& system, int numThreads) : time(0.0), stepCount(0),
numParticles(system.getNumParticles()), threads(numThreads) {
positions = new vector<Vec3>(numParticles); positions = new vector<Vec3>(numParticles);
velocities = new vector<Vec3>(numParticles); velocities = new vector<Vec3>(numParticles);
forces = new vector<Vec3>(numParticles); forces = new vector<Vec3>(numParticles);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment