"vscode:/vscode.git/clone" did not exist on "ae2a11aa37979ee7751fd9facb71731cdf73708e"
Commit b4044e73 authored by Peter Eastman's avatar Peter Eastman
Browse files

Created templatized interface for Kernel, which made a lot of code that invokes kernels cleaner.

parent fcdba25c
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008 Stanford University and the Authors. * * Portions copyright (c) 2008-2012 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -79,6 +79,20 @@ public: ...@@ -79,6 +79,20 @@ public:
* Get the object which implements this Kernel. * Get the object which implements this Kernel.
*/ */
KernelImpl& getImpl(); KernelImpl& getImpl();
/**
* Get a reference to the object which implements this Kernel, casting it to the specified type.
*/
template <class T>
const T& getAs() const {
return dynamic_cast<T&>(*impl);
}
/**
* Get a reference to the object which implements this Kernel, casting it to the specified type.
*/
template <class T>
T& getAs() {
return dynamic_cast<T&>(*impl);
}
private: private:
KernelImpl* impl; KernelImpl* impl;
}; };
......
...@@ -44,11 +44,11 @@ AndersenThermostatImpl::AndersenThermostatImpl(AndersenThermostat& owner) : owne ...@@ -44,11 +44,11 @@ AndersenThermostatImpl::AndersenThermostatImpl(AndersenThermostat& owner) : owne
void AndersenThermostatImpl::initialize(ContextImpl& context) { void AndersenThermostatImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(ApplyAndersenThermostatKernel::Name(), context); kernel = context.getPlatform().createKernel(ApplyAndersenThermostatKernel::Name(), context);
dynamic_cast<ApplyAndersenThermostatKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<ApplyAndersenThermostatKernel>().initialize(context.getSystem(), owner);
} }
void AndersenThermostatImpl::updateContextState(ContextImpl& context) { void AndersenThermostatImpl::updateContextState(ContextImpl& context) {
dynamic_cast<ApplyAndersenThermostatKernel&>(kernel.getImpl()).execute(context); kernel.getAs<ApplyAndersenThermostatKernel>().execute(context);
} }
std::map<std::string, double> AndersenThermostatImpl::getDefaultParameters() { std::map<std::string, double> AndersenThermostatImpl::getDefaultParameters() {
......
...@@ -55,7 +55,7 @@ void BrownianIntegrator::initialize(ContextImpl& contextRef) { ...@@ -55,7 +55,7 @@ void BrownianIntegrator::initialize(ContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
owner = &contextRef.getOwner(); owner = &contextRef.getOwner();
kernel = context->getPlatform().createKernel(IntegrateBrownianStepKernel::Name(), contextRef); kernel = context->getPlatform().createKernel(IntegrateBrownianStepKernel::Name(), contextRef);
dynamic_cast<IntegrateBrownianStepKernel&>(kernel.getImpl()).initialize(contextRef.getSystem(), *this); kernel.getAs<IntegrateBrownianStepKernel>().initialize(contextRef.getSystem(), *this);
} }
vector<string> BrownianIntegrator::getKernelNames() { vector<string> BrownianIntegrator::getKernelNames() {
...@@ -68,6 +68,6 @@ void BrownianIntegrator::step(int steps) { ...@@ -68,6 +68,6 @@ void BrownianIntegrator::step(int steps) {
for (int i = 0; i < steps; ++i) { for (int i = 0; i < steps; ++i) {
context->updateContextState(); context->updateContextState();
context->calcForcesAndEnergy(true, false); context->calcForcesAndEnergy(true, false);
dynamic_cast<IntegrateBrownianStepKernel&>(kernel.getImpl()).execute(*context, *this); kernel.getAs<IntegrateBrownianStepKernel>().execute(*context, *this);
} }
} }
...@@ -49,12 +49,12 @@ CMAPTorsionForceImpl::~CMAPTorsionForceImpl() { ...@@ -49,12 +49,12 @@ CMAPTorsionForceImpl::~CMAPTorsionForceImpl() {
void CMAPTorsionForceImpl::initialize(ContextImpl& context) { void CMAPTorsionForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcCMAPTorsionForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcCMAPTorsionForceKernel::Name(), context);
dynamic_cast<CalcCMAPTorsionForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCMAPTorsionForceKernel>().initialize(context.getSystem(), owner);
} }
double CMAPTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CMAPTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCMAPTorsionForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCMAPTorsionForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -45,11 +45,11 @@ CMMotionRemoverImpl::CMMotionRemoverImpl(CMMotionRemover& owner) : owner(owner) ...@@ -45,11 +45,11 @@ CMMotionRemoverImpl::CMMotionRemoverImpl(CMMotionRemover& owner) : owner(owner)
void CMMotionRemoverImpl::initialize(ContextImpl& context) { void CMMotionRemoverImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(RemoveCMMotionKernel::Name(), context); kernel = context.getPlatform().createKernel(RemoveCMMotionKernel::Name(), context);
const System& system = context.getSystem(); const System& system = context.getSystem();
dynamic_cast<RemoveCMMotionKernel&>(kernel.getImpl()).initialize(system, owner); kernel.getAs<RemoveCMMotionKernel>().initialize(system, owner);
} }
void CMMotionRemoverImpl::updateContextState(ContextImpl& context) { void CMMotionRemoverImpl::updateContextState(ContextImpl& context) {
dynamic_cast<RemoveCMMotionKernel&>(kernel.getImpl()).execute(context); kernel.getAs<RemoveCMMotionKernel>().execute(context);
} }
std::vector<std::string> CMMotionRemoverImpl::getKernelNames() { std::vector<std::string> CMMotionRemoverImpl::getKernelNames() {
......
...@@ -97,22 +97,22 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator, ...@@ -97,22 +97,22 @@ ContextImpl::ContextImpl(Context& owner, System& system, Integrator& integrator,
platform->contextCreated(*this, properties); platform->contextCreated(*this, properties);
initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this); initializeForcesKernel = platform->createKernel(CalcForcesAndEnergyKernel::Name(), *this);
dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()).initialize(system); initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>().initialize(system);
kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this); kineticEnergyKernel = platform->createKernel(CalcKineticEnergyKernel::Name(), *this);
dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).initialize(system); kineticEnergyKernel.getAs<CalcKineticEnergyKernel>().initialize(system);
updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this); updateStateDataKernel = platform->createKernel(UpdateStateDataKernel::Name(), *this);
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).initialize(system); updateStateDataKernel.getAs<UpdateStateDataKernel>().initialize(system);
applyConstraintsKernel = platform->createKernel(ApplyConstraintsKernel::Name(), *this); applyConstraintsKernel = platform->createKernel(ApplyConstraintsKernel::Name(), *this);
dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).initialize(system); applyConstraintsKernel.getAs<ApplyConstraintsKernel>().initialize(system);
virtualSitesKernel = platform->createKernel(VirtualSitesKernel::Name(), *this); virtualSitesKernel = platform->createKernel(VirtualSitesKernel::Name(), *this);
dynamic_cast<VirtualSitesKernel&>(virtualSitesKernel.getImpl()).initialize(system); virtualSitesKernel.getAs<VirtualSitesKernel>().initialize(system);
Vec3 periodicBoxVectors[3]; Vec3 periodicBoxVectors[3];
system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]); system.getDefaultPeriodicBoxVectors(periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]); updateStateDataKernel.getAs<UpdateStateDataKernel>().setPeriodicBoxVectors(*this, periodicBoxVectors[0], periodicBoxVectors[1], periodicBoxVectors[2]);
for (size_t i = 0; i < forceImpls.size(); ++i) for (size_t i = 0; i < forceImpls.size(); ++i)
forceImpls[i]->initialize(*this); forceImpls[i]->initialize(*this);
integrator.initialize(*this); integrator.initialize(*this);
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setVelocities(*this, vector<Vec3>(system.getNumParticles())); updateStateDataKernel.getAs<UpdateStateDataKernel>().setVelocities(*this, vector<Vec3>(system.getNumParticles()));
} }
ContextImpl::~ContextImpl() { ContextImpl::~ContextImpl() {
...@@ -122,33 +122,33 @@ ContextImpl::~ContextImpl() { ...@@ -122,33 +122,33 @@ ContextImpl::~ContextImpl() {
} }
double ContextImpl::getTime() const { double ContextImpl::getTime() const {
return dynamic_cast<const UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).getTime(*this); return updateStateDataKernel.getAs<const UpdateStateDataKernel>().getTime(*this);
} }
void ContextImpl::setTime(double t) { void ContextImpl::setTime(double t) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setTime(*this, t); updateStateDataKernel.getAs<UpdateStateDataKernel>().setTime(*this, t);
} }
void ContextImpl::getPositions(std::vector<Vec3>& positions) { void ContextImpl::getPositions(std::vector<Vec3>& positions) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).getPositions(*this, positions); updateStateDataKernel.getAs<UpdateStateDataKernel>().getPositions(*this, positions);
} }
void ContextImpl::setPositions(const std::vector<Vec3>& positions) { void ContextImpl::setPositions(const std::vector<Vec3>& positions) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPositions(*this, positions); updateStateDataKernel.getAs<UpdateStateDataKernel>().setPositions(*this, positions);
integrator.stateChanged(State::Positions); integrator.stateChanged(State::Positions);
} }
void ContextImpl::getVelocities(std::vector<Vec3>& velocities) { void ContextImpl::getVelocities(std::vector<Vec3>& velocities) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).getVelocities(*this, velocities); updateStateDataKernel.getAs<UpdateStateDataKernel>().getVelocities(*this, velocities);
} }
void ContextImpl::setVelocities(const std::vector<Vec3>& velocities) { void ContextImpl::setVelocities(const std::vector<Vec3>& velocities) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setVelocities(*this, velocities); updateStateDataKernel.getAs<UpdateStateDataKernel>().setVelocities(*this, velocities);
integrator.stateChanged(State::Velocities); integrator.stateChanged(State::Velocities);
} }
void ContextImpl::getForces(std::vector<Vec3>& forces) { void ContextImpl::getForces(std::vector<Vec3>& forces) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).getForces(*this, forces); updateStateDataKernel.getAs<UpdateStateDataKernel>().getForces(*this, forces);
} }
const std::map<std::string, double>& ContextImpl::getParameters() const { const std::map<std::string, double>& ContextImpl::getParameters() const {
...@@ -169,7 +169,7 @@ void ContextImpl::setParameter(std::string name, double value) { ...@@ -169,7 +169,7 @@ void ContextImpl::setParameter(std::string name, double value) {
} }
void ContextImpl::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) { void ContextImpl::getPeriodicBoxVectors(Vec3& a, Vec3& b, Vec3& c) {
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).getPeriodicBoxVectors(*this, a, b, c); updateStateDataKernel.getAs<UpdateStateDataKernel>().getPeriodicBoxVectors(*this, a, b, c);
} }
void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) { void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3& c) {
...@@ -179,20 +179,20 @@ void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3 ...@@ -179,20 +179,20 @@ void ContextImpl::setPeriodicBoxVectors(const Vec3& a, const Vec3& b, const Vec3
throw OpenMMException("Second periodic box vector must be parallel to y."); throw OpenMMException("Second periodic box vector must be parallel to y.");
if (c[0] != 0.0 || c[1] != 0.0) if (c[0] != 0.0 || c[1] != 0.0)
throw OpenMMException("Third periodic box vector must be parallel to z."); throw OpenMMException("Third periodic box vector must be parallel to z.");
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).setPeriodicBoxVectors(*this, a, b, c); updateStateDataKernel.getAs<UpdateStateDataKernel>().setPeriodicBoxVectors(*this, a, b, c);
} }
void ContextImpl::applyConstraints(double tol) { void ContextImpl::applyConstraints(double tol) {
dynamic_cast<ApplyConstraintsKernel&>(applyConstraintsKernel.getImpl()).apply(*this, tol); applyConstraintsKernel.getAs<ApplyConstraintsKernel>().apply(*this, tol);
} }
void ContextImpl::computeVirtualSites() { void ContextImpl::computeVirtualSites() {
dynamic_cast<VirtualSitesKernel&>(virtualSitesKernel.getImpl()).computePositions(*this); virtualSitesKernel.getAs<VirtualSitesKernel>().computePositions(*this);
} }
double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups) { double ContextImpl::calcForcesAndEnergy(bool includeForces, bool includeEnergy, int groups) {
lastForceGroups = groups; lastForceGroups = groups;
CalcForcesAndEnergyKernel& kernel = dynamic_cast<CalcForcesAndEnergyKernel&>(initializeForcesKernel.getImpl()); CalcForcesAndEnergyKernel& kernel = initializeForcesKernel.getAs<CalcForcesAndEnergyKernel>();
double energy = 0.0; double energy = 0.0;
kernel.beginComputation(*this, includeForces, includeEnergy, groups); kernel.beginComputation(*this, includeForces, includeEnergy, groups);
for (int i = 0; i < (int) forceImpls.size(); ++i) for (int i = 0; i < (int) forceImpls.size(); ++i)
...@@ -206,7 +206,7 @@ int ContextImpl::getLastForceGroups() const { ...@@ -206,7 +206,7 @@ int ContextImpl::getLastForceGroups() const {
} }
double ContextImpl::calcKineticEnergy() { double ContextImpl::calcKineticEnergy() {
return dynamic_cast<CalcKineticEnergyKernel&>(kineticEnergyKernel.getImpl()).execute(*this); return kineticEnergyKernel.getAs<CalcKineticEnergyKernel>().execute(*this);
} }
void ContextImpl::updateContextState() { void ContextImpl::updateContextState() {
...@@ -312,7 +312,7 @@ void ContextImpl::createCheckpoint(ostream& stream) { ...@@ -312,7 +312,7 @@ void ContextImpl::createCheckpoint(ostream& stream) {
writeString(stream, iter->first); writeString(stream, iter->first);
stream.write((char*) &iter->second, sizeof(double)); stream.write((char*) &iter->second, sizeof(double));
} }
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).createCheckpoint(*this, stream); updateStateDataKernel.getAs<UpdateStateDataKernel>().createCheckpoint(*this, stream);
stream.flush(); stream.flush();
} }
...@@ -332,5 +332,5 @@ void ContextImpl::loadCheckpoint(istream& stream) { ...@@ -332,5 +332,5 @@ void ContextImpl::loadCheckpoint(istream& stream) {
stream.read((char*) &value, sizeof(double)); stream.read((char*) &value, sizeof(double));
parameters[name] = value; parameters[name] = value;
} }
dynamic_cast<UpdateStateDataKernel&>(updateStateDataKernel.getImpl()).loadCheckpoint(*this, stream); updateStateDataKernel.getAs<UpdateStateDataKernel>().loadCheckpoint(*this, stream);
} }
...@@ -85,12 +85,12 @@ void CustomAngleForceImpl::initialize(ContextImpl& context) { ...@@ -85,12 +85,12 @@ void CustomAngleForceImpl::initialize(ContextImpl& context) {
throw OpenMMException(msg.str()); throw OpenMMException(msg.str());
} }
} }
dynamic_cast<CalcCustomAngleForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomAngleForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomAngleForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomAngleForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomAngleForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomAngleForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -79,12 +79,12 @@ void CustomBondForceImpl::initialize(ContextImpl& context) { ...@@ -79,12 +79,12 @@ void CustomBondForceImpl::initialize(ContextImpl& context) {
throw OpenMMException(msg.str()); throw OpenMMException(msg.str());
} }
} }
dynamic_cast<CalcCustomBondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomBondForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomBondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomBondForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -102,12 +102,12 @@ void CustomCompoundBondForceImpl::initialize(ContextImpl& context) { ...@@ -102,12 +102,12 @@ void CustomCompoundBondForceImpl::initialize(ContextImpl& context) {
throw OpenMMException(msg.str()); throw OpenMMException(msg.str());
} }
} }
dynamic_cast<CalcCustomCompoundBondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomCompoundBondForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomCompoundBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomCompoundBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomCompoundBondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomCompoundBondForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -73,12 +73,12 @@ void CustomExternalForceImpl::initialize(ContextImpl& context) { ...@@ -73,12 +73,12 @@ void CustomExternalForceImpl::initialize(ContextImpl& context) {
throw OpenMMException(msg.str()); throw OpenMMException(msg.str());
} }
} }
dynamic_cast<CalcCustomExternalForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomExternalForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomExternalForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomExternalForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomExternalForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomExternalForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -102,12 +102,12 @@ void CustomGBForceImpl::initialize(ContextImpl& context) { ...@@ -102,12 +102,12 @@ void CustomGBForceImpl::initialize(ContextImpl& context) {
if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2]) if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2])
throw OpenMMException("CustomGBForce: The cutoff distance cannot be greater than half the periodic box size."); throw OpenMMException("CustomGBForce: The cutoff distance cannot be greater than half the periodic box size.");
} }
dynamic_cast<CalcCustomGBForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomGBForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomGBForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomGBForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomGBForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomGBForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -175,12 +175,12 @@ void CustomHbondForceImpl::initialize(ContextImpl& context) { ...@@ -175,12 +175,12 @@ void CustomHbondForceImpl::initialize(ContextImpl& context) {
if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2]) if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2])
throw OpenMMException("CustomHbondForce: The cutoff distance cannot be greater than half the periodic box size."); throw OpenMMException("CustomHbondForce: The cutoff distance cannot be greater than half the periodic box size.");
} }
dynamic_cast<CalcCustomHbondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomHbondForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomHbondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomHbondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomHbondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomHbondForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -54,12 +54,12 @@ void CustomIntegrator::initialize(ContextImpl& contextRef) { ...@@ -54,12 +54,12 @@ void CustomIntegrator::initialize(ContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
owner = &contextRef.getOwner(); owner = &contextRef.getOwner();
kernel = context->getPlatform().createKernel(IntegrateCustomStepKernel::Name(), contextRef); kernel = context->getPlatform().createKernel(IntegrateCustomStepKernel::Name(), contextRef);
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).initialize(contextRef.getSystem(), *this); kernel.getAs<IntegrateCustomStepKernel>().initialize(contextRef.getSystem(), *this);
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).setGlobalVariables(contextRef, globalValues); kernel.getAs<IntegrateCustomStepKernel>().setGlobalVariables(contextRef, globalValues);
for (int i = 0; i < (int) perDofValues.size(); i++) { for (int i = 0; i < (int) perDofValues.size(); i++) {
if (perDofValues[i].size() == 1) if (perDofValues[i].size() == 1)
perDofValues[i].resize(context->getSystem().getNumParticles(), perDofValues[i][0]); perDofValues[i].resize(context->getSystem().getNumParticles(), perDofValues[i][0]);
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).setPerDofVariable(contextRef, i, perDofValues[i]); kernel.getAs<IntegrateCustomStepKernel>().setPerDofVariable(contextRef, i, perDofValues[i]);
} }
} }
...@@ -76,7 +76,7 @@ vector<string> CustomIntegrator::getKernelNames() { ...@@ -76,7 +76,7 @@ vector<string> CustomIntegrator::getKernelNames() {
void CustomIntegrator::step(int steps) { void CustomIntegrator::step(int steps) {
globalsAreCurrent = false; globalsAreCurrent = false;
for (int i = 0; i < steps; ++i) { for (int i = 0; i < steps; ++i) {
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).execute(*context, *this, forcesAreValid); kernel.getAs<IntegrateCustomStepKernel>().execute(*context, *this, forcesAreValid);
} }
} }
...@@ -109,7 +109,7 @@ const string& CustomIntegrator::getPerDofVariableName(int index) const { ...@@ -109,7 +109,7 @@ const string& CustomIntegrator::getPerDofVariableName(int index) const {
double CustomIntegrator::getGlobalVariable(int index) const { double CustomIntegrator::getGlobalVariable(int index) const {
ASSERT_VALID_INDEX(index, globalValues); ASSERT_VALID_INDEX(index, globalValues);
if (owner != NULL && !globalsAreCurrent) { if (owner != NULL && !globalsAreCurrent) {
dynamic_cast<const IntegrateCustomStepKernel&>(kernel.getImpl()).getGlobalVariables(*context, globalValues); kernel.getAs<const IntegrateCustomStepKernel>().getGlobalVariables(*context, globalValues);
globalsAreCurrent = true; globalsAreCurrent = true;
} }
return globalValues[index]; return globalValues[index];
...@@ -118,12 +118,12 @@ double CustomIntegrator::getGlobalVariable(int index) const { ...@@ -118,12 +118,12 @@ double CustomIntegrator::getGlobalVariable(int index) const {
void CustomIntegrator::setGlobalVariable(int index, double value) { void CustomIntegrator::setGlobalVariable(int index, double value) {
ASSERT_VALID_INDEX(index, globalValues); ASSERT_VALID_INDEX(index, globalValues);
if (owner != NULL && !globalsAreCurrent) { if (owner != NULL && !globalsAreCurrent) {
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).getGlobalVariables(*context, globalValues); kernel.getAs<IntegrateCustomStepKernel>().getGlobalVariables(*context, globalValues);
globalsAreCurrent = true; globalsAreCurrent = true;
} }
globalValues[index] = value; globalValues[index] = value;
if (owner != NULL) if (owner != NULL)
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).setGlobalVariables(*context, globalValues); kernel.getAs<IntegrateCustomStepKernel>().setGlobalVariables(*context, globalValues);
} }
void CustomIntegrator::setGlobalVariableByName(const string& name, double value) { void CustomIntegrator::setGlobalVariableByName(const string& name, double value) {
...@@ -140,7 +140,7 @@ void CustomIntegrator::getPerDofVariable(int index, vector<Vec3>& values) const ...@@ -140,7 +140,7 @@ void CustomIntegrator::getPerDofVariable(int index, vector<Vec3>& values) const
if (owner == NULL) if (owner == NULL)
values = perDofValues[index]; values = perDofValues[index];
else else
dynamic_cast<const IntegrateCustomStepKernel&>(kernel.getImpl()).getPerDofVariable(*context, index, values); kernel.getAs<const IntegrateCustomStepKernel>().getPerDofVariable(*context, index, values);
} }
void CustomIntegrator::setPerDofVariable(int index, const vector<Vec3>& values) { void CustomIntegrator::setPerDofVariable(int index, const vector<Vec3>& values) {
...@@ -150,7 +150,7 @@ void CustomIntegrator::setPerDofVariable(int index, const vector<Vec3>& values) ...@@ -150,7 +150,7 @@ void CustomIntegrator::setPerDofVariable(int index, const vector<Vec3>& values)
if (owner == NULL) if (owner == NULL)
perDofValues[index] = values; perDofValues[index] = values;
else else
dynamic_cast<IntegrateCustomStepKernel&>(kernel.getImpl()).setPerDofVariable(*context, index, values); kernel.getAs<IntegrateCustomStepKernel>().setPerDofVariable(*context, index, values);
} }
void CustomIntegrator::setPerDofVariableByName(const string& name, const vector<Vec3>& value) { void CustomIntegrator::setPerDofVariableByName(const string& name, const vector<Vec3>& value) {
......
...@@ -102,12 +102,12 @@ void CustomNonbondedForceImpl::initialize(ContextImpl& context) { ...@@ -102,12 +102,12 @@ void CustomNonbondedForceImpl::initialize(ContextImpl& context) {
if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2]) if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2])
throw OpenMMException("CustomNonbondedForce: The cutoff distance cannot be greater than half the periodic box size."); throw OpenMMException("CustomNonbondedForce: The cutoff distance cannot be greater than half the periodic box size.");
} }
dynamic_cast<CalcCustomNonbondedForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomNonbondedForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomNonbondedForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomNonbondedForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomNonbondedForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomNonbondedForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -91,12 +91,12 @@ void CustomTorsionForceImpl::initialize(ContextImpl& context) { ...@@ -91,12 +91,12 @@ void CustomTorsionForceImpl::initialize(ContextImpl& context) {
throw OpenMMException(msg.str()); throw OpenMMException(msg.str());
} }
} }
dynamic_cast<CalcCustomTorsionForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcCustomTorsionForceKernel>().initialize(context.getSystem(), owner);
} }
double CustomTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double CustomTorsionForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcCustomTorsionForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcCustomTorsionForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -52,12 +52,12 @@ void GBSAOBCForceImpl::initialize(ContextImpl& context) { ...@@ -52,12 +52,12 @@ void GBSAOBCForceImpl::initialize(ContextImpl& context) {
if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2]) if (cutoff > 0.5*boxVectors[0][0] || cutoff > 0.5*boxVectors[1][1] || cutoff > 0.5*boxVectors[2][2])
throw OpenMMException("GBSAOBCForce: The cutoff distance cannot be greater than half the periodic box size."); throw OpenMMException("GBSAOBCForce: The cutoff distance cannot be greater than half the periodic box size.");
} }
dynamic_cast<CalcGBSAOBCForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcGBSAOBCForceKernel>().initialize(context.getSystem(), owner);
} }
double GBSAOBCForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double GBSAOBCForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcGBSAOBCForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcGBSAOBCForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -106,7 +106,7 @@ void GBVIForceImpl::initialize(ContextImpl& context) { ...@@ -106,7 +106,7 @@ void GBVIForceImpl::initialize(ContextImpl& context) {
scaledRadii.resize(numberOfParticles); scaledRadii.resize(numberOfParticles);
findScaledRadii( numberOfParticles, bondIndices, bondLengths, scaledRadii); findScaledRadii( numberOfParticles, bondIndices, bondLengths, scaledRadii);
dynamic_cast<CalcGBVIForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner, scaledRadii); kernel.getAs<CalcGBVIForceKernel>().initialize(context.getSystem(), owner, scaledRadii);
} }
int GBVIForceImpl::getBondsFromForces(ContextImpl& context) { int GBVIForceImpl::getBondsFromForces(ContextImpl& context) {
...@@ -241,7 +241,7 @@ void GBVIForceImpl::findScaledRadii( int numberOfParticles, const std::vector<st ...@@ -241,7 +241,7 @@ void GBVIForceImpl::findScaledRadii( int numberOfParticles, const std::vector<st
double GBVIForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double GBVIForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcGBVIForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcGBVIForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -46,12 +46,12 @@ HarmonicAngleForceImpl::~HarmonicAngleForceImpl() { ...@@ -46,12 +46,12 @@ HarmonicAngleForceImpl::~HarmonicAngleForceImpl() {
void HarmonicAngleForceImpl::initialize(ContextImpl& context) { void HarmonicAngleForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcHarmonicAngleForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcHarmonicAngleForceKernel::Name(), context);
dynamic_cast<CalcHarmonicAngleForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcHarmonicAngleForceKernel>().initialize(context.getSystem(), owner);
} }
double HarmonicAngleForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double HarmonicAngleForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcHarmonicAngleForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcHarmonicAngleForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -46,12 +46,12 @@ HarmonicBondForceImpl::~HarmonicBondForceImpl() { ...@@ -46,12 +46,12 @@ HarmonicBondForceImpl::~HarmonicBondForceImpl() {
void HarmonicBondForceImpl::initialize(ContextImpl& context) { void HarmonicBondForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcHarmonicBondForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcHarmonicBondForceKernel::Name(), context);
dynamic_cast<CalcHarmonicBondForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcHarmonicBondForceKernel>().initialize(context.getSystem(), owner);
} }
double HarmonicBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double HarmonicBondForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcHarmonicBondForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcHarmonicBondForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -55,7 +55,7 @@ void LangevinIntegrator::initialize(ContextImpl& contextRef) { ...@@ -55,7 +55,7 @@ void LangevinIntegrator::initialize(ContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
owner = &contextRef.getOwner(); owner = &contextRef.getOwner();
kernel = context->getPlatform().createKernel(IntegrateLangevinStepKernel::Name(), contextRef); kernel = context->getPlatform().createKernel(IntegrateLangevinStepKernel::Name(), contextRef);
dynamic_cast<IntegrateLangevinStepKernel&>(kernel.getImpl()).initialize(contextRef.getSystem(), *this); kernel.getAs<IntegrateLangevinStepKernel>().initialize(contextRef.getSystem(), *this);
} }
vector<string> LangevinIntegrator::getKernelNames() { vector<string> LangevinIntegrator::getKernelNames() {
...@@ -68,6 +68,6 @@ void LangevinIntegrator::step(int steps) { ...@@ -68,6 +68,6 @@ void LangevinIntegrator::step(int steps) {
for (int i = 0; i < steps; ++i) { for (int i = 0; i < steps; ++i) {
context->updateContextState(); context->updateContextState();
context->calcForcesAndEnergy(true, false); context->calcForcesAndEnergy(true, false);
dynamic_cast<IntegrateLangevinStepKernel&>(kernel.getImpl()).execute(*context, *this); kernel.getAs<IntegrateLangevinStepKernel>().execute(*context, *this);
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment