Unverified Commit 6817921b authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Unified threading for Reference and CPU platforms (#4987)

parent 17f085d5
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2022 Stanford University and the Authors. *
* Portions copyright (c) 2013-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -89,7 +89,7 @@ private:
class CpuPlatform::PlatformData {
public:
PlatformData(int numParticles, int numThreads, bool deterministicForces);
PlatformData(int numParticles, ThreadPool& threads, bool deterministicForces);
~PlatformData();
/**
* Request that a neighbor list be built and maintained.
......@@ -107,7 +107,7 @@ public:
int requestPosqIndex();
AlignedArray<float> posq;
std::vector<AlignedArray<float> > threadForce;
ThreadPool threads;
ThreadPool& threads;
bool isPeriodic;
CpuRandom random;
std::map<std::string, std::string> propertyValues;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2013-2024 Stanford University and the Authors. *
* Portions copyright (c) 2013-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -112,18 +112,19 @@ bool CpuPlatform::isProcessorSupported() {
}
void CpuPlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
ReferencePlatform::contextCreated(context, properties);
const string& threadsPropValue = (properties.find(CpuThreads()) == properties.end() ?
getPropertyDefaultValue(CpuThreads()) : properties.find(CpuThreads())->second);
map<string, string> refProperties = properties;
refProperties["Threads"] = threadsPropValue;
ReferencePlatform::contextCreated(context, refProperties);
string deterministicForcesValue = (properties.find(CpuDeterministicForces()) == properties.end() ?
getPropertyDefaultValue(CpuDeterministicForces()) : properties.find(CpuDeterministicForces())->second);
int numThreads;
stringstream(threadsPropValue) >> numThreads;
transform(deterministicForcesValue.begin(), deterministicForcesValue.end(), deterministicForcesValue.begin(), ::tolower);
bool deterministicForces = (deterministicForcesValue == "true");
PlatformData* data = new PlatformData(context.getSystem().getNumParticles(), numThreads, deterministicForces);
ReferencePlatform::PlatformData* refData = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
PlatformData* data = new PlatformData(context.getSystem().getNumParticles(), refData->threads, deterministicForces);
contextData[&context] = data;
ReferenceConstraints& constraints = *(ReferenceConstraints*) reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData())->constraints;
ReferenceConstraints& constraints = *(ReferenceConstraints*) refData->constraints;
if (constraints.settle != NULL) {
CpuSETTLE* parallelSettle = new CpuSETTLE(context.getSystem(), *(ReferenceSETTLEAlgorithm*) constraints.settle, data->threads);
delete constraints.settle;
......@@ -147,10 +148,10 @@ const CpuPlatform::PlatformData& CpuPlatform::getPlatformData(const ContextImpl&
return *contextData[&context];
}
CpuPlatform::PlatformData::PlatformData(int numParticles, int numThreads, bool deterministicForces) : posq(4*numParticles), threads(numThreads),
CpuPlatform::PlatformData::PlatformData(int numParticles, ThreadPool& threads, bool deterministicForces) : posq(4*numParticles), threads(threads),
deterministicForces(deterministicForces), numParticles(numParticles), neighborList(NULL), cutoff(0.0), paddedCutoff(0.0), anyExclusions(false),
currentPosqIndex(-1), nextPosqIndex(0) {
numThreads = threads.getNumThreads();
int numThreads = threads.getNumThreads();
threadForce.resize(numThreads);
for (int i = 0; i < numThreads; i++)
threadForce[i].resize(4*numParticles);
......
......@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2023 Stanford University and the Authors. *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -34,6 +34,7 @@
#include "openmm/Platform.h"
#include "openmm/System.h"
#include "openmm/internal/ThreadPool.h"
#include "openmm/internal/windowsExport.h"
#include "ReferenceConstraints.h"
#include "ReferenceVirtualSites.h"
......@@ -62,11 +63,12 @@ public:
class OPENMM_EXPORT ReferencePlatform::PlatformData {
public:
PlatformData(const System& system);
PlatformData(const System& system, int numThreads);
~PlatformData();
int numParticles;
long long stepCount;
double time;
ThreadPool threads;
std::vector<Vec3>* positions;
std::vector<Vec3>* velocities;
std::vector<Vec3>* forces;
......
......@@ -131,6 +131,11 @@ static map<string, double>& extractEnergyParameterDerivatives(ContextImpl& conte
return *data->energyParameterDerivatives;
}
static ThreadPool& extractThreadPool(ContextImpl& context) {
ReferencePlatform::PlatformData* data = reinterpret_cast<ReferencePlatform::PlatformData*>(context.getPlatformData());
return data->threads;
}
/**
* Make sure an expression doesn't use any undefined variables.
*/
......@@ -1339,13 +1344,13 @@ double ReferenceCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bo
// Add in the long range correction.
if (!hasInitializedLongRangeCorrection) {
ThreadPool threads;
ThreadPool& threads = extractThreadPool(context);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(*forceCopy, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true;
}
else if (globalParamsChanged && forceCopy != NULL) {
ThreadPool threads;
ThreadPool& threads = extractThreadPool(context);
CustomNonbondedForceImpl::calcLongRangeCorrection(*forceCopy, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
}
double volume = boxVectors[0][0]*boxVectors[1][1]*boxVectors[2][2];
......@@ -1372,7 +1377,7 @@ void ReferenceCalcCustomNonbondedForceKernel::copyParametersToContext(ContextImp
// If necessary, recompute the long range correction.
if (forceCopy != NULL) {
ThreadPool threads;
ThreadPool& threads = extractThreadPool(context);
longRangeCorrectionData = CustomNonbondedForceImpl::prepareLongRangeCorrection(force, threads.getNumThreads());
CustomNonbondedForceImpl::calcLongRangeCorrection(force, longRangeCorrectionData, context.getOwner(), longRangeCoefficient, longRangeCoefficientDerivs, threads);
hasInitializedLongRangeCorrection = true;
......
......@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2024 Stanford University and the Authors. *
* Portions copyright (c) 2008-2025 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
......@@ -35,6 +35,7 @@
#include "openmm/internal/ContextImpl.h"
#include "SimTKOpenMMRealType.h"
#include "openmm/Vec3.h"
#include <sstream>
using namespace OpenMM;
using namespace std;
......@@ -89,7 +90,10 @@ bool ReferencePlatform::supportsDoublePrecision() const {
}
void ReferencePlatform::contextCreated(ContextImpl& context, const map<string, string>& properties) const {
context.setPlatformData(new PlatformData(context.getSystem()));
int numThreads = 0;
if (properties.find("Threads") != properties.end())
stringstream(properties.at("Threads")) >> numThreads;
context.setPlatformData(new PlatformData(context.getSystem(), numThreads));
}
void ReferencePlatform::contextDestroyed(ContextImpl& context) const {
......@@ -97,7 +101,8 @@ void ReferencePlatform::contextDestroyed(ContextImpl& context) const {
delete data;
}
ReferencePlatform::PlatformData::PlatformData(const System& system) : time(0.0), stepCount(0), numParticles(system.getNumParticles()) {
ReferencePlatform::PlatformData::PlatformData(const System& system, int numThreads) : time(0.0), stepCount(0),
numParticles(system.getNumParticles()), threads(numThreads) {
positions = new vector<Vec3>(numParticles);
velocities = new vector<Vec3>(numParticles);
forces = new vector<Vec3>(numParticles);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment