Commit b4044e73 authored by Peter Eastman's avatar Peter Eastman
Browse files

Created templatized interface for Kernel, which made a lot of code that invokes kernels cleaner.

parent fcdba25c
...@@ -45,12 +45,12 @@ void GBSAOBCSoftcoreForceImpl::initialize(ContextImpl& context) { ...@@ -45,12 +45,12 @@ void GBSAOBCSoftcoreForceImpl::initialize(ContextImpl& context) {
kernel = context.getPlatform().createKernel(CalcGBSAOBCSoftcoreForceKernel::Name(), context); kernel = context.getPlatform().createKernel(CalcGBSAOBCSoftcoreForceKernel::Name(), context);
if (owner.getNumParticles() != context.getSystem().getNumParticles()) if (owner.getNumParticles() != context.getSystem().getNumParticles())
throw OpenMMException("GBSAOBCSoftcoreForce must have exactly as many particles as the System it belongs to."); throw OpenMMException("GBSAOBCSoftcoreForce must have exactly as many particles as the System it belongs to.");
dynamic_cast<CalcGBSAOBCSoftcoreForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcGBSAOBCSoftcoreForceKernel>().initialize(context.getSystem(), owner);
} }
double GBSAOBCSoftcoreForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double GBSAOBCSoftcoreForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcGBSAOBCSoftcoreForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcGBSAOBCSoftcoreForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -111,7 +111,7 @@ void GBVISoftcoreForceImpl::initialize(ContextImpl& context) { ...@@ -111,7 +111,7 @@ void GBVISoftcoreForceImpl::initialize(ContextImpl& context) {
scaledRadii.resize(numberOfParticles); scaledRadii.resize(numberOfParticles);
findScaledRadii( numberOfParticles, bondIndices, bondLengths, scaledRadii); findScaledRadii( numberOfParticles, bondIndices, bondLengths, scaledRadii);
dynamic_cast<CalcGBVISoftcoreForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner, scaledRadii); kernel.getAs<CalcGBVISoftcoreForceKernel>().initialize(context.getSystem(), owner, scaledRadii);
} }
int GBVISoftcoreForceImpl::getBondsFromForces(ContextImpl& context) { int GBVISoftcoreForceImpl::getBondsFromForces(ContextImpl& context) {
...@@ -263,7 +263,7 @@ void GBVISoftcoreForceImpl::findScaledRadii( int numberOfParticles, const std::v ...@@ -263,7 +263,7 @@ void GBVISoftcoreForceImpl::findScaledRadii( int numberOfParticles, const std::v
double GBVISoftcoreForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double GBVISoftcoreForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcGBVISoftcoreForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcGBVISoftcoreForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -84,12 +84,12 @@ void NonbondedSoftcoreForceImpl::initialize(ContextImpl& context) { ...@@ -84,12 +84,12 @@ void NonbondedSoftcoreForceImpl::initialize(ContextImpl& context) {
exceptions[particle1].insert(particle2); exceptions[particle1].insert(particle2);
exceptions[particle2].insert(particle1); exceptions[particle2].insert(particle1);
} }
dynamic_cast<CalcNonbondedSoftcoreForceKernel&>(kernel.getImpl()).initialize(context.getSystem(), owner); kernel.getAs<CalcNonbondedSoftcoreForceKernel>().initialize(context.getSystem(), owner);
} }
double NonbondedSoftcoreForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { double NonbondedSoftcoreForceImpl::calcForcesAndEnergy(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
if ((groups&(1<<owner.getForceGroup())) != 0) if ((groups&(1<<owner.getForceGroup())) != 0)
return dynamic_cast<CalcNonbondedSoftcoreForceKernel&>(kernel.getImpl()).execute(context, includeForces, includeEnergy); return kernel.getAs<CalcNonbondedSoftcoreForceKernel>().execute(context, includeForces, includeEnergy);
return 0.0; return 0.0;
} }
......
...@@ -58,7 +58,7 @@ void RPMDIntegrator::initialize(ContextImpl& contextRef) { ...@@ -58,7 +58,7 @@ void RPMDIntegrator::initialize(ContextImpl& contextRef) {
context = &contextRef; context = &contextRef;
owner = &contextRef.getOwner(); owner = &contextRef.getOwner();
kernel = context->getPlatform().createKernel(IntegrateRPMDStepKernel::Name(), contextRef); kernel = context->getPlatform().createKernel(IntegrateRPMDStepKernel::Name(), contextRef);
dynamic_cast<IntegrateRPMDStepKernel&>(kernel.getImpl()).initialize(contextRef.getSystem(), *this); kernel.getAs<IntegrateRPMDStepKernel>().initialize(contextRef.getSystem(), *this);
} }
void RPMDIntegrator::stateChanged(State::DataType changed) { void RPMDIntegrator::stateChanged(State::DataType changed) {
...@@ -72,17 +72,17 @@ vector<string> RPMDIntegrator::getKernelNames() { ...@@ -72,17 +72,17 @@ vector<string> RPMDIntegrator::getKernelNames() {
} }
void RPMDIntegrator::setPositions(int copy, const vector<Vec3>& positions) { void RPMDIntegrator::setPositions(int copy, const vector<Vec3>& positions) {
dynamic_cast<IntegrateRPMDStepKernel&>(kernel.getImpl()).setPositions(copy, positions); kernel.getAs<IntegrateRPMDStepKernel>().setPositions(copy, positions);
hasSetPosition = true; hasSetPosition = true;
} }
void RPMDIntegrator::setVelocities(int copy, const vector<Vec3>& velocities) { void RPMDIntegrator::setVelocities(int copy, const vector<Vec3>& velocities) {
dynamic_cast<IntegrateRPMDStepKernel&>(kernel.getImpl()).setVelocities(copy, velocities); kernel.getAs<IntegrateRPMDStepKernel>().setVelocities(copy, velocities);
hasSetVelocity = true; hasSetVelocity = true;
} }
State RPMDIntegrator::getState(int copy, int types, bool enforcePeriodicBox, int groups) { State RPMDIntegrator::getState(int copy, int types, bool enforcePeriodicBox, int groups) {
dynamic_cast<IntegrateRPMDStepKernel&>(kernel.getImpl()).copyToContext(copy, *context); kernel.getAs<IntegrateRPMDStepKernel>().copyToContext(copy, *context);
return context->getOwner().getState(types, enforcePeriodicBox, groups); return context->getOwner().getState(types, enforcePeriodicBox, groups);
} }
...@@ -102,7 +102,7 @@ void RPMDIntegrator::step(int steps) { ...@@ -102,7 +102,7 @@ void RPMDIntegrator::step(int steps) {
setVelocities(i, s.getVelocities()); setVelocities(i, s.getVelocities());
} }
for (int i = 0; i < steps; ++i) { for (int i = 0; i < steps; ++i) {
dynamic_cast<IntegrateRPMDStepKernel&>(kernel.getImpl()).execute(*context, *this, forcesAreValid); kernel.getAs<IntegrateRPMDStepKernel>().execute(*context, *this, forcesAreValid);
forcesAreValid = true; forcesAreValid = true;
} }
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment