"vscode:/vscode.git/clone" did not exist on "7705f533b487a2372357c0bc38aecc9e7dcf1a35"
Commit 3d750e85 authored by Peter Eastman's avatar Peter Eastman
Browse files

Modified the initialization order of ForceImpls to prevent them from trying to...

Modified the initialization order of ForceImpls to prevent them from trying to initialize themselves before the Platform had been chosen
parent a5288d55
...@@ -208,6 +208,7 @@ ENDFOREACH(subdir) ...@@ -208,6 +208,7 @@ ENDFOREACH(subdir)
INCLUDE_DIRECTORIES(BEFORE ${PROJECT_SOURCE_DIR}/src) INCLUDE_DIRECTORIES(BEFORE ${PROJECT_SOURCE_DIR}/src)
ADD_LIBRARY(${SHARED_TARGET} SHARED ${SOURCE_FILES} ${SOURCE_INCLUDE_FILES} ${API_ABS_INCLUDE_FILES}) ADD_LIBRARY(${SHARED_TARGET} SHARED ${SOURCE_FILES} ${SOURCE_INCLUDE_FILES} ${API_ABS_INCLUDE_FILES})
ADD_LIBRARY(${STATIC_TARGET} STATIC ${SOURCE_FILES} ${SOURCE_INCLUDE_FILES} ${API_ABS_INCLUDE_FILES})
# #
# Allow automated build and dashboard. # Allow automated build and dashboard.
......
...@@ -72,7 +72,7 @@ public: ...@@ -72,7 +72,7 @@ public:
return defaultFreq; return defaultFreq;
} }
protected: protected:
ForceImpl* createImpl(OpenMMContextImpl& context); ForceImpl* createImpl();
private: private:
double defaultTemp, defaultFreq; double defaultTemp, defaultFreq;
}; };
......
...@@ -62,7 +62,7 @@ protected: ...@@ -62,7 +62,7 @@ protected:
* It should create a new ForceImpl object which can be used by the context for calculating forces. * It should create a new ForceImpl object which can be used by the context for calculating forces.
* The ForceImpl will be deleted automatically when the OpenMMContext is deleted. * The ForceImpl will be deleted automatically when the OpenMMContext is deleted.
*/ */
virtual ForceImpl* createImpl(OpenMMContextImpl& context) = 0; virtual ForceImpl* createImpl() = 0;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -98,7 +98,7 @@ public: ...@@ -98,7 +98,7 @@ public:
soluteDielectric = dielectric; soluteDielectric = dielectric;
} }
protected: protected:
ForceImpl* createImpl(OpenMMContextImpl& context); ForceImpl* createImpl();
private: private:
class AtomInfo; class AtomInfo;
double solventDielectric, soluteDielectric; double solventDielectric, soluteDielectric;
......
...@@ -209,7 +209,7 @@ public: ...@@ -209,7 +209,7 @@ public:
*/ */
void setRBTorsionParameters(int index, int atom1, int atom2, int atom3, int atom4, double c0, double c1, double c2, double c3, double c4, double c5); void setRBTorsionParameters(int index, int atom1, int atom2, int atom3, int atom4, double c0, double c1, double c2, double c3, double c4, double c5);
protected: protected:
ForceImpl* createImpl(OpenMMContextImpl& context); ForceImpl* createImpl();
private: private:
class AtomInfo; class AtomInfo;
class BondInfo; class BondInfo;
......
...@@ -45,7 +45,8 @@ namespace OpenMM { ...@@ -45,7 +45,8 @@ namespace OpenMM {
class AndersenThermostatImpl : public ForceImpl { class AndersenThermostatImpl : public ForceImpl {
public: public:
AndersenThermostatImpl(AndersenThermostat& owner, OpenMMContextImpl& context); AndersenThermostatImpl(AndersenThermostat& owner);
void initialize(OpenMMContextImpl& context);
AndersenThermostat& getOwner() { AndersenThermostat& getOwner() {
return owner; return owner;
} }
......
...@@ -56,6 +56,11 @@ class ForceImpl { ...@@ -56,6 +56,11 @@ class ForceImpl {
public: public:
virtual ~ForceImpl() { virtual ~ForceImpl() {
} }
/**
* This is called after the ForceImpl is created and before updateContextState(), calcForces(),
* or calcEnergy() is called on it. This allows it to do any necessary initialization.
*/
virtual void initialize(OpenMMContextImpl& context) = 0;
/** /**
* Get the Force object from which this ForceImpl was created. * Get the Force object from which this ForceImpl was created.
*/ */
......
...@@ -45,7 +45,8 @@ namespace OpenMM { ...@@ -45,7 +45,8 @@ namespace OpenMM {
class GBSAOBCForceFieldImpl : public ForceImpl { class GBSAOBCForceFieldImpl : public ForceImpl {
public: public:
GBSAOBCForceFieldImpl(GBSAOBCForceField& owner, OpenMMContextImpl& context); GBSAOBCForceFieldImpl(GBSAOBCForceField& owner);
void initialize(OpenMMContextImpl& context);
GBSAOBCForceField& getOwner() { GBSAOBCForceField& getOwner() {
return owner; return owner;
} }
...@@ -59,10 +60,8 @@ public: ...@@ -59,10 +60,8 @@ public:
} }
std::vector<std::string> getKernelNames(); std::vector<std::string> getKernelNames();
private: private:
void initialize(OpenMMContextImpl& context);
GBSAOBCForceField& owner; GBSAOBCForceField& owner;
Kernel kernel; Kernel kernel;
bool hasInitialized;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -47,8 +47,9 @@ namespace OpenMM { ...@@ -47,8 +47,9 @@ namespace OpenMM {
class StandardMMForceFieldImpl : public ForceImpl { class StandardMMForceFieldImpl : public ForceImpl {
public: public:
StandardMMForceFieldImpl(StandardMMForceField& owner, OpenMMContextImpl& context); StandardMMForceFieldImpl(StandardMMForceField& owner);
~StandardMMForceFieldImpl(); ~StandardMMForceFieldImpl();
void initialize(OpenMMContextImpl& context);
StandardMMForceField& getOwner() { StandardMMForceField& getOwner() {
return owner; return owner;
} }
......
...@@ -41,6 +41,6 @@ AndersenThermostat::AndersenThermostat(double defaultTemperature, double default ...@@ -41,6 +41,6 @@ AndersenThermostat::AndersenThermostat(double defaultTemperature, double default
defaultTemp(defaultTemperature), defaultFreq(defaultCollisionFrequency) { defaultTemp(defaultTemperature), defaultFreq(defaultCollisionFrequency) {
} }
ForceImpl* AndersenThermostat::createImpl(OpenMMContextImpl& context) { ForceImpl* AndersenThermostat::createImpl() {
return new AndersenThermostatImpl(*this, context); return new AndersenThermostatImpl(*this);
} }
...@@ -39,7 +39,10 @@ ...@@ -39,7 +39,10 @@
using namespace OpenMM; using namespace OpenMM;
using std::vector; using std::vector;
AndersenThermostatImpl::AndersenThermostatImpl(AndersenThermostat& owner, OpenMMContextImpl& context) : owner(owner) { AndersenThermostatImpl::AndersenThermostatImpl(AndersenThermostat& owner) : owner(owner) {
}
void AndersenThermostatImpl::initialize(OpenMMContextImpl& context) {
kernel = context.getPlatform().createKernel(ApplyAndersenThermostatKernel::Name()); kernel = context.getPlatform().createKernel(ApplyAndersenThermostatKernel::Name());
const System& system = context.getSystem(); const System& system = context.getSystem();
vector<double> masses(system.getNumAtoms()); vector<double> masses(system.getNumAtoms());
......
...@@ -51,6 +51,6 @@ void GBSAOBCForceField::setAtomParameters(int index, double charge, double radiu ...@@ -51,6 +51,6 @@ void GBSAOBCForceField::setAtomParameters(int index, double charge, double radiu
atoms[index].scalingFactor = scalingFactor; atoms[index].scalingFactor = scalingFactor;
} }
ForceImpl* GBSAOBCForceField::createImpl(OpenMMContextImpl& context) { ForceImpl* GBSAOBCForceField::createImpl() {
return new GBSAOBCForceFieldImpl(*this, context); return new GBSAOBCForceFieldImpl(*this);
} }
...@@ -37,11 +37,10 @@ ...@@ -37,11 +37,10 @@
using namespace OpenMM; using namespace OpenMM;
using std::vector; using std::vector;
GBSAOBCForceFieldImpl::GBSAOBCForceFieldImpl(GBSAOBCForceField& owner, OpenMMContextImpl& context) : owner(owner), hasInitialized(false) { GBSAOBCForceFieldImpl::GBSAOBCForceFieldImpl(GBSAOBCForceField& owner) : owner(owner) {
} }
void GBSAOBCForceFieldImpl::initialize(OpenMMContextImpl& context) { void GBSAOBCForceFieldImpl::initialize(OpenMMContextImpl& context) {
hasInitialized = true;
kernel = context.getPlatform().createKernel(CalcGBSAOBCForceFieldKernel::Name()); kernel = context.getPlatform().createKernel(CalcGBSAOBCForceFieldKernel::Name());
vector<vector<double> > atomParameters(owner.getNumAtoms()); vector<vector<double> > atomParameters(owner.getNumAtoms());
for (int i = 0; i < owner.getNumAtoms(); ++i) { for (int i = 0; i < owner.getNumAtoms(); ++i) {
...@@ -55,14 +54,10 @@ void GBSAOBCForceFieldImpl::initialize(OpenMMContextImpl& context) { ...@@ -55,14 +54,10 @@ void GBSAOBCForceFieldImpl::initialize(OpenMMContextImpl& context) {
} }
void GBSAOBCForceFieldImpl::calcForces(OpenMMContextImpl& context, Stream& forces) { void GBSAOBCForceFieldImpl::calcForces(OpenMMContextImpl& context, Stream& forces) {
if (!hasInitialized)
initialize(context);
dynamic_cast<CalcGBSAOBCForceFieldKernel&>(kernel.getImpl()).executeForces(context.getPositions(), forces); dynamic_cast<CalcGBSAOBCForceFieldKernel&>(kernel.getImpl()).executeForces(context.getPositions(), forces);
} }
double GBSAOBCForceFieldImpl::calcEnergy(OpenMMContextImpl& context) { double GBSAOBCForceFieldImpl::calcEnergy(OpenMMContextImpl& context) {
if (!hasInitialized)
initialize(context);
return dynamic_cast<CalcGBSAOBCForceFieldKernel&>(kernel.getImpl()).executeEnergy(context.getPositions()); return dynamic_cast<CalcGBSAOBCForceFieldKernel&>(kernel.getImpl()).executeEnergy(context.getPositions());
} }
......
...@@ -49,7 +49,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -49,7 +49,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
vector<string> kernelNames; vector<string> kernelNames;
kernelNames.push_back(CalcKineticEnergyKernel::Name()); kernelNames.push_back(CalcKineticEnergyKernel::Name());
for (int i = 0; i < system.getNumForces(); ++i) { for (int i = 0; i < system.getNumForces(); ++i) {
forceImpls.push_back(system.getForce(i).createImpl(*this)); forceImpls.push_back(system.getForce(i).createImpl());
map<string, double> forceParameters = forceImpls[forceImpls.size()-1]->getDefaultParameters(); map<string, double> forceParameters = forceImpls[forceImpls.size()-1]->getDefaultParameters();
parameters.insert(forceParameters.begin(), forceParameters.end()); parameters.insert(forceParameters.begin(), forceParameters.end());
vector<string> forceKernels = forceImpls[forceImpls.size()-1]->getKernelNames(); vector<string> forceKernels = forceImpls[forceImpls.size()-1]->getKernelNames();
...@@ -58,7 +58,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -58,7 +58,7 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
vector<string> integratorKernels = integrator.getKernelNames(); vector<string> integratorKernels = integrator.getKernelNames();
kernelNames.insert(kernelNames.begin(), integratorKernels.begin(), integratorKernels.end()); kernelNames.insert(kernelNames.begin(), integratorKernels.begin(), integratorKernels.end());
if (platform == 0) if (platform == 0)
platform = &Platform::findPlatform(kernelNames); this->platform = platform = &Platform::findPlatform(kernelNames);
else if (!platform->supportsKernels(kernelNames)) else if (!platform->supportsKernels(kernelNames))
throw OpenMMException("Specified a Platform for an OpenMMContext which does not support all required kernels"); throw OpenMMException("Specified a Platform for an OpenMMContext which does not support all required kernels");
positions = platform->createStream("atomPositions", system.getNumAtoms(), Stream::Double3); positions = platform->createStream("atomPositions", system.getNumAtoms(), Stream::Double3);
...@@ -71,6 +71,8 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ ...@@ -71,6 +71,8 @@ OpenMMContextImpl::OpenMMContextImpl(OpenMMContext& owner, System& system, Integ
for (int i = 0; i < masses.size(); ++i) for (int i = 0; i < masses.size(); ++i)
masses[i] = system.getAtomMass(i); masses[i] = system.getAtomMass(i);
dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(masses); dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(masses);
for (int i = 0; i < forceImpls.size(); ++i)
forceImpls[i]->initialize(*this);
integrator.initialize(*this); integrator.initialize(*this);
} }
...@@ -118,7 +120,9 @@ void OpenMMContextImpl::reinitialize() { ...@@ -118,7 +120,9 @@ void OpenMMContextImpl::reinitialize() {
for (int i = 0; i < (int) forceImpls.size(); ++i) for (int i = 0; i < (int) forceImpls.size(); ++i)
delete forceImpls[i]; delete forceImpls[i];
forceImpls.resize(0); forceImpls.resize(0);
for (int i = 0; i < system.getNumForces(); ++i) for (int i = 0; i < system.getNumForces(); ++i) {
forceImpls.push_back(system.getForce(i).createImpl(*this)); forceImpls.push_back(system.getForce(i).createImpl());
forceImpls[i]->initialize(*this);
}
integrator.initialize(*this); integrator.initialize(*this);
} }
...@@ -128,6 +128,6 @@ void StandardMMForceField::setRBTorsionParameters(int index, int atom1, int atom ...@@ -128,6 +128,6 @@ void StandardMMForceField::setRBTorsionParameters(int index, int atom1, int atom
rbTorsions[index].c[5] = c5; rbTorsions[index].c[5] = c5;
} }
ForceImpl* StandardMMForceField::createImpl(OpenMMContextImpl& context) { ForceImpl* StandardMMForceField::createImpl() {
return new StandardMMForceFieldImpl(*this, context); return new StandardMMForceFieldImpl(*this);
} }
...@@ -38,7 +38,13 @@ using std::pair; ...@@ -38,7 +38,13 @@ using std::pair;
using std::vector; using std::vector;
using std::set; using std::set;
StandardMMForceFieldImpl::StandardMMForceFieldImpl(StandardMMForceField& owner, OpenMMContextImpl& context) : owner(owner) { StandardMMForceFieldImpl::StandardMMForceFieldImpl(StandardMMForceField& owner) : owner(owner) {
}
StandardMMForceFieldImpl::~StandardMMForceFieldImpl() {
}
void StandardMMForceFieldImpl::initialize(OpenMMContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcStandardMMForceFieldKernel::Name()); kernel = context.getPlatform().createKernel(CalcStandardMMForceFieldKernel::Name());
vector<vector<int> > bondIndices(owner.getNumBonds()); vector<vector<int> > bondIndices(owner.getNumBonds());
vector<vector<double> > bondParameters(owner.getNumBonds()); vector<vector<double> > bondParameters(owner.getNumBonds());
...@@ -116,9 +122,6 @@ StandardMMForceFieldImpl::StandardMMForceFieldImpl(StandardMMForceField& owner, ...@@ -116,9 +122,6 @@ StandardMMForceFieldImpl::StandardMMForceFieldImpl(StandardMMForceField& owner,
periodicTorsionIndices, periodicTorsionParameters, rbTorsionIndices, rbTorsionParameters, bonded14Indices, 0.5, 1.0/1.2, exclusions, nonbondedParameters); periodicTorsionIndices, periodicTorsionParameters, rbTorsionIndices, rbTorsionParameters, bonded14Indices, 0.5, 1.0/1.2, exclusions, nonbondedParameters);
} }
StandardMMForceFieldImpl::~StandardMMForceFieldImpl() {
}
void StandardMMForceFieldImpl::calcForces(OpenMMContextImpl& context, Stream& forces) { void StandardMMForceFieldImpl::calcForces(OpenMMContextImpl& context, Stream& forces) {
dynamic_cast<CalcStandardMMForceFieldKernel&>(kernel.getImpl()).executeForces(context.getPositions(), forces); dynamic_cast<CalcStandardMMForceFieldKernel&>(kernel.getImpl()).executeForces(context.getPositions(), forces);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment