Commit a711ce2a authored by Peter Eastman's avatar Peter Eastman
Browse files

Refactored ForceImpl, lots of KernelImpl subclasses, and other related classes...

Refactored ForceImpl, lots of KernelImpl subclasses, and other related classes to avoid redundant calculations when requesting a State with both forces and energies
parent 767ea1bd
......@@ -56,8 +56,7 @@ public:
void updateContextState(ContextImpl& context) {
// This force field doesn't update the state directly.
}
void calcForces(ContextImpl& context);
double calcEnergy(ContextImpl& context);
double calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy);
std::map<std::string, double> getDefaultParameters() {
return std::map<std::string, double>(); // This force field doesn't define any parameters.
}
......
......@@ -63,7 +63,7 @@ vector<string> BrownianIntegrator::getKernelNames() {
void BrownianIntegrator::step(int steps) {
for (int i = 0; i < steps; ++i) {
context->updateContextState();
context->calcForces();
context->calcForcesAndEnergy(true, false);
dynamic_cast<IntegrateBrownianStepKernel&>(kernel.getImpl()).execute(*context, *this);
}
}
......@@ -52,12 +52,8 @@ void CMAPTorsionForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCMAPTorsionForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CMAPTorsionForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCMAPTorsionForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CMAPTorsionForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCMAPTorsionForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CMAPTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCMAPTorsionForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CMAPTorsionForceImpl::getKernelNames() {
......
......@@ -82,11 +82,14 @@ State Context::getState(int types) const {
Vec3 periodicBoxSize[3];
impl->getPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
state.setPeriodicBoxVectors(periodicBoxSize[0], periodicBoxSize[1], periodicBoxSize[2]);
if (types&State::Energy)
state.setEnergy(impl->calcKineticEnergy(), impl->calcPotentialEnergy());
if (types&State::Forces) {
impl->calcForces();
impl->getForces(state.updForces());
bool includeForces = types&State::Forces;
bool includeEnergy = types&State::Energy;
if (includeForces || includeEnergy) {
double energy = impl->calcForcesAndEnergy(includeForces, includeEnergy);
if (includeEnergy)
state.setEnergy(impl->calcKineticEnergy(), energy);
if (includeForces)
impl->getForces(state.updForces());
}
if (types&State::Parameters) {
for (map<string, double>::const_iterator iter = impl->parameters.begin(); iter != impl->parameters.end(); iter++)
......
......@@ -148,28 +148,20 @@ void ContextImpl::applyConstraints(double tol) {
dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).apply(*this, tol);
}
void ContextImpl::calcForces() {
double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy) {
CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl());
kernel.beginForceComputation(*this);
double energy = 0.0;
kernel.beginComputation(*this, includeForces, includeEnergy);
for (int i = 0; i < (int) forceImpls.size(); ++i)
forceImpls[i]->calcForces(*this);
kernel.finishForceComputation(*this);
energy += forceImpls[i]->calcForcesAndEnergy(*this, includeForces, includeEnergy);
energy += kernel.finishComputation(*this, includeForces, includeEnergy);
return energy;
}
double ContextImpl::calcKineticEnergy() {
return dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).execute(*this);
}
double ContextImpl::calcPotentialEnergy() {
CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl());
kernel.beginEnergyComputation(*this);
double energy = 0.0;
for (int i = 0; i < (int) forceImpls.size(); ++i)
energy += forceImpls[i]->calcEnergy(*this);
energy += kernel.finishEnergyComputation(*this);
return energy;
}
void ContextImpl::updateContextState() {
for (int i = 0; i < (int) forceImpls.size(); ++i)
forceImpls[i]->updateContextState(*this);
......
......@@ -88,12 +88,8 @@ void CustomAngleForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomAngleForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomAngleForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomAngleForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomAngleForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomAngleForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomAngleForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomAngleForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomAngleForceImpl::getKernelNames() {
......
......@@ -82,12 +82,8 @@ void CustomBondForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomBondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomBondForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomBondForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomBondForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomBondForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomBondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomBondForceImpl::getKernelNames() {
......
......@@ -76,12 +76,8 @@ void CustomExternalForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomExternalForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomExternalForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomExternalForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomExternalForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomExternalForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomExternalForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomExternalForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomExternalForceImpl::getKernelNames() {
......
......@@ -105,12 +105,8 @@ void CustomGBForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomGBForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomGBForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomGBForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomGBForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomGBForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomGBForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomGBForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomGBForceImpl::getKernelNames() {
......
......@@ -178,12 +178,8 @@ void CustomHbondForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomHbondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomHbondForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomHbondForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomHbondForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomHbondForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomHbondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomHbondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomHbondForceImpl::getKernelNames() {
......
......@@ -105,12 +105,8 @@ void CustomNonbondedForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomNonbondedForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomNonbondedForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomNonbondedForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomNonbondedForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomNonbondedForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomNonbondedForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomNonbondedForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomNonbondedForceImpl::getKernelNames() {
......
......@@ -94,12 +94,8 @@ void CustomTorsionForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcCustomTorsionForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void CustomTorsionForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcCustomTorsionForceKernel&>(kernel.getImpl()).executeForces(context);
}
double CustomTorsionForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcCustomTorsionForceKernel&>(kernel.getImpl()).executeEnergy(context);
double CustomTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcCustomTorsionForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
vector<string> CustomTorsionForceImpl::getKernelNames() {
......
......@@ -55,12 +55,8 @@ void GBSAOBCForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcGBSAOBCForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void GBSAOBCForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcGBSAOBCForceKernel&>(kernel.getImpl()).executeForces(context);
}
double GBSAOBCForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcGBSAOBCForceKernel&>(kernel.getImpl()).executeEnergy(context);
double GBSAOBCForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcGBSAOBCForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> GBSAOBCForceImpl::getKernelNames() {
......
......@@ -239,12 +239,8 @@ void GBVIForceImpl::findScaledRadii( int numberOfParticles, const std::vector<st
}
void GBVIForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcGBVIForceKernel&>(kernel.getImpl()).executeForces(context);
}
double GBVIForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcGBVIForceKernel&>(kernel.getImpl()).executeEnergy(context);
double GBVIForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcGBVIForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> GBVIForceImpl::getKernelNames() {
......
......@@ -49,12 +49,8 @@ void HarmonicAngleForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcHarmonicAngleForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void HarmonicAngleForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcHarmonicAngleForceKernel&>(kernel.getImpl()).executeForces(context);
}
double HarmonicAngleForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcHarmonicAngleForceKernel&>(kernel.getImpl()).executeEnergy(context);
double HarmonicAngleForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcHarmonicAngleForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> HarmonicAngleForceImpl::getKernelNames() {
......
......@@ -49,12 +49,8 @@ void HarmonicBondForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcHarmonicBondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void HarmonicBondForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcHarmonicBondForceKernel&>(kernel.getImpl()).executeForces(context);
}
double HarmonicBondForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcHarmonicBondForceKernel&>(kernel.getImpl()).executeEnergy(context);
double HarmonicBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcHarmonicBondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> HarmonicBondForceImpl::getKernelNames() {
......
......@@ -63,7 +63,7 @@ vector<string> LangevinIntegrator::getKernelNames() {
void LangevinIntegrator::step(int steps) {
for (int i = 0; i < steps; ++i) {
context->updateContextState();
context->calcForces();
context->calcForcesAndEnergy(true, false);
dynamic_cast<IntegrateLangevinStepKernel&>(kernel.getImpl()).execute(*context, *this);
}
}
......@@ -98,12 +98,8 @@ void NonbondedForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcNonbondedForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void NonbondedForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcNonbondedForceKernel&>(kernel.getImpl()).executeForces(context);
}
double NonbondedForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcNonbondedForceKernel&>(kernel.getImpl()).executeEnergy(context);
double NonbondedForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcNonbondedForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> NonbondedForceImpl::getKernelNames() {
......
......@@ -49,12 +49,8 @@ void PeriodicTorsionForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcPeriodicTorsionForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void PeriodicTorsionForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcPeriodicTorsionForceKernel&>(kernel.getImpl()).executeForces(context);
}
double PeriodicTorsionForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcPeriodicTorsionForceKernel&>(kernel.getImpl()).executeEnergy(context);
double PeriodicTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcPeriodicTorsionForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> PeriodicTorsionForceImpl::getKernelNames() {
......
......@@ -49,12 +49,8 @@ void RBTorsionForceImpl::initialize(ContextImpl& context) {
dynamic_cast<CalcRBTorsionForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner);
}
void RBTorsionForceImpl::calcForces(ContextImpl& context) {
dynamic_cast<CalcRBTorsionForceKernel&>(kernel.getImpl()).executeForces(context);
}
double RBTorsionForceImpl::calcEnergy(ContextImpl& context) {
return dynamic_cast<CalcRBTorsionForceKernel&>(kernel.getImpl()).executeEnergy(context);
double RBTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy) {
return dynamic_cast<CalcRBTorsionForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy);
}
std::vector<std::string> RBTorsionForceImpl::getKernelNames() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment