Unverified Commit 756e479c authored by peastman's avatar peastman Committed by GitHub
Browse files

Merge pull request #2129 from peastman/ixn

Optimization to custom forces on CPU
parents 2985aaea 89f52648
/* Portions copyright (c) 2010-2016 Stanford University and Simbios. /* Portions copyright (c) 2010-2018 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -51,7 +51,7 @@ class ReferenceCustomAngleIxn : public ReferenceBondIxn { ...@@ -51,7 +51,7 @@ class ReferenceCustomAngleIxn : public ReferenceBondIxn {
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomAngleIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpression, ReferenceCustomAngleIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpression,
const std::vector<std::string>& parameterNames, std::map<std::string, double> globalParameters, const std::vector<std::string>& parameterNames,
const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions); const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -72,6 +72,14 @@ class ReferenceCustomAngleIxn : public ReferenceBondIxn { ...@@ -72,6 +72,14 @@ class ReferenceCustomAngleIxn : public ReferenceBondIxn {
void setPeriodic(OpenMM::Vec3* vectors); void setPeriodic(OpenMM::Vec3* vectors);
/**---------------------------------------------------------------------------------------
Set the values of all global parameters.
--------------------------------------------------------------------------------------- */
void setGlobalParameters(std::map<std::string, double> parameters);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom Angle Ixn Calculate Custom Angle Ixn
......
/* Portions copyright (c) 2009-2016 Stanford University and Simbios. /* Portions copyright (c) 2009-2018 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -52,7 +52,7 @@ class ReferenceCustomBondIxn : public ReferenceBondIxn { ...@@ -52,7 +52,7 @@ class ReferenceCustomBondIxn : public ReferenceBondIxn {
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomBondIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpression, ReferenceCustomBondIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpression,
const std::vector<std::string>& parameterNames, std::map<std::string, double> globalParameters, const std::vector<std::string>& parameterNames,
const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions); const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -73,6 +73,14 @@ class ReferenceCustomBondIxn : public ReferenceBondIxn { ...@@ -73,6 +73,14 @@ class ReferenceCustomBondIxn : public ReferenceBondIxn {
void setPeriodic(OpenMM::Vec3* vectors); void setPeriodic(OpenMM::Vec3* vectors);
/**---------------------------------------------------------------------------------------
Set the values of all global parameters.
--------------------------------------------------------------------------------------- */
void setGlobalParameters(std::map<std::string, double> parameters);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom Bond Ixn Calculate Custom Bond Ixn
......
...@@ -58,7 +58,7 @@ class ReferenceCustomExternalIxn { ...@@ -58,7 +58,7 @@ class ReferenceCustomExternalIxn {
ReferenceCustomExternalIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpressionX, ReferenceCustomExternalIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpressionX,
const Lepton::CompiledExpression& forceExpressionY, const Lepton::CompiledExpression& forceExpressionZ, const Lepton::CompiledExpression& forceExpressionY, const Lepton::CompiledExpression& forceExpressionZ,
const std::vector<std::string>& parameterNames, std::map<std::string, double> globalParameters); const std::vector<std::string>& parameterNames);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -68,6 +68,14 @@ class ReferenceCustomExternalIxn { ...@@ -68,6 +68,14 @@ class ReferenceCustomExternalIxn {
~ReferenceCustomExternalIxn(); ~ReferenceCustomExternalIxn();
/**---------------------------------------------------------------------------------------
Set the values of all global parameters.
--------------------------------------------------------------------------------------- */
void setGlobalParameters(std::map<std::string, double> parameters);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom External Force Calculate Custom External Force
......
/* Portions copyright (c) 2010-2016 Stanford University and Simbios. /* Portions copyright (c) 2010-2018 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -51,7 +51,7 @@ class ReferenceCustomTorsionIxn : public ReferenceBondIxn { ...@@ -51,7 +51,7 @@ class ReferenceCustomTorsionIxn : public ReferenceBondIxn {
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomTorsionIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpression, ReferenceCustomTorsionIxn(const Lepton::CompiledExpression& energyExpression, const Lepton::CompiledExpression& forceExpression,
const std::vector<std::string>& parameterNames, std::map<std::string, double> globalParameters, const std::vector<std::string>& parameterNames,
const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions); const std::vector<Lepton::CompiledExpression> energyParamDerivExpressions);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -72,6 +72,14 @@ class ReferenceCustomTorsionIxn : public ReferenceBondIxn { ...@@ -72,6 +72,14 @@ class ReferenceCustomTorsionIxn : public ReferenceBondIxn {
void setPeriodic(OpenMM::Vec3* vectors); void setPeriodic(OpenMM::Vec3* vectors);
/**---------------------------------------------------------------------------------------
Set the values of all global parameters.
--------------------------------------------------------------------------------------- */
void setGlobalParameters(std::map<std::string, double> parameters);
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom Torsion Ixn Calculate Custom Torsion Ixn
......
...@@ -45,6 +45,10 @@ namespace OpenMM { ...@@ -45,6 +45,10 @@ namespace OpenMM {
class ReferenceObc; class ReferenceObc;
class ReferenceAndersenThermostat; class ReferenceAndersenThermostat;
class ReferenceCustomBondIxn;
class ReferenceCustomAngleIxn;
class ReferenceCustomTorsionIxn;
class ReferenceCustomExternalIxn;
class ReferenceCustomCentroidBondIxn; class ReferenceCustomCentroidBondIxn;
class ReferenceCustomCompoundBondIxn; class ReferenceCustomCompoundBondIxn;
class ReferenceCustomCVForce; class ReferenceCustomCVForce;
...@@ -296,8 +300,9 @@ private: ...@@ -296,8 +300,9 @@ private:
*/ */
class ReferenceCalcCustomBondForceKernel : public CalcCustomBondForceKernel { class ReferenceCalcCustomBondForceKernel : public CalcCustomBondForceKernel {
public: public:
ReferenceCalcCustomBondForceKernel(std::string name, const Platform& platform) : CalcCustomBondForceKernel(name, platform) { ReferenceCalcCustomBondForceKernel(std::string name, const Platform& platform) : CalcCustomBondForceKernel(name, platform), ixn(NULL) {
} }
~ReferenceCalcCustomBondForceKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -323,6 +328,7 @@ public: ...@@ -323,6 +328,7 @@ public:
void copyParametersToContext(ContextImpl& context, const CustomBondForce& force); void copyParametersToContext(ContextImpl& context, const CustomBondForce& force);
private: private:
int numBonds; int numBonds;
ReferenceCustomBondIxn* ixn;
std::vector<std::vector<int> >bondIndexArray; std::vector<std::vector<int> >bondIndexArray;
std::vector<std::vector<double> >bondParamArray; std::vector<std::vector<double> >bondParamArray;
Lepton::CompiledExpression energyExpression, forceExpression; Lepton::CompiledExpression energyExpression, forceExpression;
...@@ -373,8 +379,9 @@ private: ...@@ -373,8 +379,9 @@ private:
*/ */
class ReferenceCalcCustomAngleForceKernel : public CalcCustomAngleForceKernel { class ReferenceCalcCustomAngleForceKernel : public CalcCustomAngleForceKernel {
public: public:
ReferenceCalcCustomAngleForceKernel(std::string name, const Platform& platform) : CalcCustomAngleForceKernel(name, platform) { ReferenceCalcCustomAngleForceKernel(std::string name, const Platform& platform) : CalcCustomAngleForceKernel(name, platform), ixn(NULL) {
} }
~ReferenceCalcCustomAngleForceKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -400,6 +407,7 @@ public: ...@@ -400,6 +407,7 @@ public:
void copyParametersToContext(ContextImpl& context, const CustomAngleForce& force); void copyParametersToContext(ContextImpl& context, const CustomAngleForce& force);
private: private:
int numAngles; int numAngles;
ReferenceCustomAngleIxn* ixn;
std::vector<std::vector<int> >angleIndexArray; std::vector<std::vector<int> >angleIndexArray;
std::vector<std::vector<double> >angleParamArray; std::vector<std::vector<double> >angleParamArray;
Lepton::CompiledExpression energyExpression, forceExpression; Lepton::CompiledExpression energyExpression, forceExpression;
...@@ -524,8 +532,9 @@ private: ...@@ -524,8 +532,9 @@ private:
*/ */
class ReferenceCalcCustomTorsionForceKernel : public CalcCustomTorsionForceKernel { class ReferenceCalcCustomTorsionForceKernel : public CalcCustomTorsionForceKernel {
public: public:
ReferenceCalcCustomTorsionForceKernel(std::string name, const Platform& platform) : CalcCustomTorsionForceKernel(name, platform) { ReferenceCalcCustomTorsionForceKernel(std::string name, const Platform& platform) : CalcCustomTorsionForceKernel(name, platform), ixn(NULL) {
} }
~ReferenceCalcCustomTorsionForceKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -551,6 +560,7 @@ public: ...@@ -551,6 +560,7 @@ public:
void copyParametersToContext(ContextImpl& context, const CustomTorsionForce& force); void copyParametersToContext(ContextImpl& context, const CustomTorsionForce& force);
private: private:
int numTorsions; int numTorsions;
ReferenceCustomTorsionIxn* ixn;
std::vector<std::vector<int> >torsionIndexArray; std::vector<std::vector<int> >torsionIndexArray;
std::vector<std::vector<double> >torsionParamArray; std::vector<std::vector<double> >torsionParamArray;
Lepton::CompiledExpression energyExpression, forceExpression; Lepton::CompiledExpression energyExpression, forceExpression;
...@@ -766,8 +776,9 @@ private: ...@@ -766,8 +776,9 @@ private:
*/ */
class ReferenceCalcCustomExternalForceKernel : public CalcCustomExternalForceKernel { class ReferenceCalcCustomExternalForceKernel : public CalcCustomExternalForceKernel {
public: public:
ReferenceCalcCustomExternalForceKernel(std::string name, const Platform& platform) : CalcCustomExternalForceKernel(name, platform) { ReferenceCalcCustomExternalForceKernel(std::string name, const Platform& platform) : CalcCustomExternalForceKernel(name, platform), ixn(NULL) {
} }
~ReferenceCalcCustomExternalForceKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -794,6 +805,7 @@ public: ...@@ -794,6 +805,7 @@ public:
private: private:
class PeriodicDistanceFunction; class PeriodicDistanceFunction;
int numParticles; int numParticles;
ReferenceCustomExternalIxn* ixn;
std::vector<int> particles; std::vector<int> particles;
std::vector<std::vector<double> > particleParamArray; std::vector<std::vector<double> > particleParamArray;
Lepton::CompiledExpression energyExpression, forceExpressionX, forceExpressionY, forceExpressionZ; Lepton::CompiledExpression energyExpression, forceExpressionX, forceExpressionY, forceExpressionZ;
......
...@@ -383,6 +383,11 @@ void ReferenceCalcHarmonicBondForceKernel::copyParametersToContext(ContextImpl& ...@@ -383,6 +383,11 @@ void ReferenceCalcHarmonicBondForceKernel::copyParametersToContext(ContextImpl&
} }
} }
ReferenceCalcCustomBondForceKernel::~ReferenceCalcCustomBondForceKernel() {
if (ixn != NULL)
delete ixn;
}
void ReferenceCalcCustomBondForceKernel::initialize(const System& system, const CustomBondForce& force) { void ReferenceCalcCustomBondForceKernel::initialize(const System& system, const CustomBondForce& force) {
numBonds = force.getNumBonds(); numBonds = force.getNumBonds();
int numParameters = force.getNumPerBondParameters(); int numParameters = force.getNumPerBondParameters();
...@@ -421,6 +426,7 @@ void ReferenceCalcCustomBondForceKernel::initialize(const System& system, const ...@@ -421,6 +426,7 @@ void ReferenceCalcCustomBondForceKernel::initialize(const System& system, const
variables.insert(parameterNames.begin(), parameterNames.end()); variables.insert(parameterNames.begin(), parameterNames.end());
variables.insert(globalParameterNames.begin(), globalParameterNames.end()); variables.insert(globalParameterNames.begin(), globalParameterNames.end());
validateVariables(expression.getRootNode(), variables); validateVariables(expression.getRootNode(), variables);
ixn = new ReferenceCustomBondIxn(energyExpression, forceExpression, parameterNames, energyParamDerivExpressions);
} }
double ReferenceCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double ReferenceCalcCustomBondForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -430,12 +436,12 @@ double ReferenceCalcCustomBondForceKernel::execute(ContextImpl& context, bool in ...@@ -430,12 +436,12 @@ double ReferenceCalcCustomBondForceKernel::execute(ContextImpl& context, bool in
map<string, double> globalParameters; map<string, double> globalParameters;
for (auto& name : globalParameterNames) for (auto& name : globalParameterNames)
globalParameters[name] = context.getParameter(name); globalParameters[name] = context.getParameter(name);
ReferenceCustomBondIxn bond(energyExpression, forceExpression, parameterNames, globalParameters, energyParamDerivExpressions); ixn->setGlobalParameters(globalParameters);
if (usePeriodic) if (usePeriodic)
bond.setPeriodic(extractBoxVectors(context)); ixn->setPeriodic(extractBoxVectors(context));
vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0); vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
for (int i = 0; i < numBonds; i++) for (int i = 0; i < numBonds; i++)
bond.calculateBondIxn(bondIndexArray[i], posData, bondParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]); ixn->calculateBondIxn(bondIndexArray[i], posData, bondParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]);
map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context); map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
for (int i = 0; i < energyParamDerivNames.size(); i++) for (int i = 0; i < energyParamDerivNames.size(); i++)
energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i]; energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i];
...@@ -506,6 +512,11 @@ void ReferenceCalcHarmonicAngleForceKernel::copyParametersToContext(ContextImpl& ...@@ -506,6 +512,11 @@ void ReferenceCalcHarmonicAngleForceKernel::copyParametersToContext(ContextImpl&
} }
} }
ReferenceCalcCustomAngleForceKernel::~ReferenceCalcCustomAngleForceKernel() {
if (ixn != NULL)
delete ixn;
}
void ReferenceCalcCustomAngleForceKernel::initialize(const System& system, const CustomAngleForce& force) { void ReferenceCalcCustomAngleForceKernel::initialize(const System& system, const CustomAngleForce& force) {
numAngles = force.getNumAngles(); numAngles = force.getNumAngles();
int numParameters = force.getNumPerAngleParameters(); int numParameters = force.getNumPerAngleParameters();
...@@ -545,6 +556,7 @@ void ReferenceCalcCustomAngleForceKernel::initialize(const System& system, const ...@@ -545,6 +556,7 @@ void ReferenceCalcCustomAngleForceKernel::initialize(const System& system, const
variables.insert(parameterNames.begin(), parameterNames.end()); variables.insert(parameterNames.begin(), parameterNames.end());
variables.insert(globalParameterNames.begin(), globalParameterNames.end()); variables.insert(globalParameterNames.begin(), globalParameterNames.end());
validateVariables(expression.getRootNode(), variables); validateVariables(expression.getRootNode(), variables);
ixn = new ReferenceCustomAngleIxn(energyExpression, forceExpression, parameterNames, energyParamDerivExpressions);
} }
double ReferenceCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double ReferenceCalcCustomAngleForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -554,12 +566,12 @@ double ReferenceCalcCustomAngleForceKernel::execute(ContextImpl& context, bool i ...@@ -554,12 +566,12 @@ double ReferenceCalcCustomAngleForceKernel::execute(ContextImpl& context, bool i
map<string, double> globalParameters; map<string, double> globalParameters;
for (auto& name : globalParameterNames) for (auto& name : globalParameterNames)
globalParameters[name] = context.getParameter(name); globalParameters[name] = context.getParameter(name);
ReferenceCustomAngleIxn customAngle(energyExpression, forceExpression, parameterNames, globalParameters, energyParamDerivExpressions); ixn->setGlobalParameters(globalParameters);
if (usePeriodic) if (usePeriodic)
customAngle.setPeriodic(extractBoxVectors(context)); ixn->setPeriodic(extractBoxVectors(context));
vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0); vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
for (int i = 0; i < numAngles; i++) for (int i = 0; i < numAngles; i++)
customAngle.calculateBondIxn(angleIndexArray[i], posData, angleParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]); ixn->calculateBondIxn(angleIndexArray[i], posData, angleParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]);
map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context); map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
for (int i = 0; i < energyParamDerivNames.size(); i++) for (int i = 0; i < energyParamDerivNames.size(); i++)
energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i]; energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i];
...@@ -760,6 +772,11 @@ void ReferenceCalcCMAPTorsionForceKernel::copyParametersToContext(ContextImpl& c ...@@ -760,6 +772,11 @@ void ReferenceCalcCMAPTorsionForceKernel::copyParametersToContext(ContextImpl& c
} }
} }
ReferenceCalcCustomTorsionForceKernel::~ReferenceCalcCustomTorsionForceKernel() {
if (ixn != NULL)
delete ixn;
}
void ReferenceCalcCustomTorsionForceKernel::initialize(const System& system, const CustomTorsionForce& force) { void ReferenceCalcCustomTorsionForceKernel::initialize(const System& system, const CustomTorsionForce& force) {
numTorsions = force.getNumTorsions(); numTorsions = force.getNumTorsions();
int numParameters = force.getNumPerTorsionParameters(); int numParameters = force.getNumPerTorsionParameters();
...@@ -800,6 +817,7 @@ void ReferenceCalcCustomTorsionForceKernel::initialize(const System& system, con ...@@ -800,6 +817,7 @@ void ReferenceCalcCustomTorsionForceKernel::initialize(const System& system, con
variables.insert(parameterNames.begin(), parameterNames.end()); variables.insert(parameterNames.begin(), parameterNames.end());
variables.insert(globalParameterNames.begin(), globalParameterNames.end()); variables.insert(globalParameterNames.begin(), globalParameterNames.end());
validateVariables(expression.getRootNode(), variables); validateVariables(expression.getRootNode(), variables);
ixn = new ReferenceCustomTorsionIxn(energyExpression, forceExpression, parameterNames, energyParamDerivExpressions);
} }
double ReferenceCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double ReferenceCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -809,12 +827,12 @@ double ReferenceCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool ...@@ -809,12 +827,12 @@ double ReferenceCalcCustomTorsionForceKernel::execute(ContextImpl& context, bool
map<string, double> globalParameters; map<string, double> globalParameters;
for (auto& name : globalParameterNames) for (auto& name : globalParameterNames)
globalParameters[name] = context.getParameter(name); globalParameters[name] = context.getParameter(name);
ReferenceCustomTorsionIxn customTorsion(energyExpression, forceExpression, parameterNames, globalParameters, energyParamDerivExpressions); ixn->setGlobalParameters(globalParameters);
if (usePeriodic) if (usePeriodic)
customTorsion.setPeriodic(extractBoxVectors(context)); ixn->setPeriodic(extractBoxVectors(context));
vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0); vector<double> energyParamDerivValues(energyParamDerivNames.size()+1, 0.0);
for (int i = 0; i < numTorsions; i++) for (int i = 0; i < numTorsions; i++)
customTorsion.calculateBondIxn(torsionIndexArray[i], posData, torsionParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]); ixn->calculateBondIxn(torsionIndexArray[i], posData, torsionParamArray[i], forceData, includeEnergy ? &energy : NULL, &energyParamDerivValues[0]);
map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context); map<string, double>& energyParamDerivs = extractEnergyParameterDerivatives(context);
for (int i = 0; i < energyParamDerivNames.size(); i++) for (int i = 0; i < energyParamDerivNames.size(); i++)
energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i]; energyParamDerivs[energyParamDerivNames[i]] += energyParamDerivValues[i];
...@@ -1515,6 +1533,11 @@ Lepton::CustomFunction* ReferenceCalcCustomExternalForceKernel::PeriodicDistance ...@@ -1515,6 +1533,11 @@ Lepton::CustomFunction* ReferenceCalcCustomExternalForceKernel::PeriodicDistance
return new PeriodicDistanceFunction(boxVectorHandle); return new PeriodicDistanceFunction(boxVectorHandle);
} }
ReferenceCalcCustomExternalForceKernel::~ReferenceCalcCustomExternalForceKernel() {
if (ixn != NULL)
delete ixn;
}
void ReferenceCalcCustomExternalForceKernel::initialize(const System& system, const CustomExternalForce& force) { void ReferenceCalcCustomExternalForceKernel::initialize(const System& system, const CustomExternalForce& force) {
numParticles = force.getNumParticles(); numParticles = force.getNumParticles();
int numParameters = force.getNumPerParticleParameters(); int numParameters = force.getNumPerParticleParameters();
...@@ -1547,6 +1570,8 @@ void ReferenceCalcCustomExternalForceKernel::initialize(const System& system, co ...@@ -1547,6 +1570,8 @@ void ReferenceCalcCustomExternalForceKernel::initialize(const System& system, co
variables.insert(parameterNames.begin(), parameterNames.end()); variables.insert(parameterNames.begin(), parameterNames.end());
variables.insert(globalParameterNames.begin(), globalParameterNames.end()); variables.insert(globalParameterNames.begin(), globalParameterNames.end());
validateVariables(expression.getRootNode(), variables); validateVariables(expression.getRootNode(), variables);
ixn = new ReferenceCustomExternalIxn(energyExpression, forceExpressionX, forceExpressionY, forceExpressionZ, parameterNames);
} }
double ReferenceCalcCustomExternalForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) { double ReferenceCalcCustomExternalForceKernel::execute(ContextImpl& context, bool includeForces, bool includeEnergy) {
...@@ -1557,9 +1582,9 @@ double ReferenceCalcCustomExternalForceKernel::execute(ContextImpl& context, boo ...@@ -1557,9 +1582,9 @@ double ReferenceCalcCustomExternalForceKernel::execute(ContextImpl& context, boo
map<string, double> globalParameters; map<string, double> globalParameters;
for (auto& name : globalParameterNames) for (auto& name : globalParameterNames)
globalParameters[name] = context.getParameter(name); globalParameters[name] = context.getParameter(name);
ReferenceCustomExternalIxn force(energyExpression, forceExpressionX, forceExpressionY, forceExpressionZ, parameterNames, globalParameters); ixn->setGlobalParameters(globalParameters);
for (int i = 0; i < numParticles; ++i) for (int i = 0; i < numParticles; ++i)
force.calculateForce(particles[i], posData, particleParamArray[i], forceData, includeEnergy ? &energy : NULL); ixn->calculateForce(particles[i], posData, particleParamArray[i], forceData, includeEnergy ? &energy : NULL);
return energy; return energy;
} }
......
/* Portions copyright (c) 2010-2016 Stanford University and Simbios. /* Portions copyright (c) 2010-2018 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -38,7 +38,7 @@ using namespace std; ...@@ -38,7 +38,7 @@ using namespace std;
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomAngleIxn::ReferenceCustomAngleIxn(const Lepton::CompiledExpression& energyExpression, ReferenceCustomAngleIxn::ReferenceCustomAngleIxn(const Lepton::CompiledExpression& energyExpression,
const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames, map<string, double> globalParameters, const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames,
const vector<Lepton::CompiledExpression> energyParamDerivExpressions) : const vector<Lepton::CompiledExpression> energyParamDerivExpressions) :
energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) { energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) {
expressionSet.registerExpression(this->energyExpression); expressionSet.registerExpression(this->energyExpression);
...@@ -49,8 +49,6 @@ ReferenceCustomAngleIxn::ReferenceCustomAngleIxn(const Lepton::CompiledExpressio ...@@ -49,8 +49,6 @@ ReferenceCustomAngleIxn::ReferenceCustomAngleIxn(const Lepton::CompiledExpressio
numParameters = parameterNames.size(); numParameters = parameterNames.size();
for (auto& param : parameterNames) for (auto& param : parameterNames)
angleParamIndex.push_back(expressionSet.getVariableIndex(param)); angleParamIndex.push_back(expressionSet.getVariableIndex(param));
for (auto& param : globalParameters)
expressionSet.setVariable(expressionSet.getVariableIndex(param.first), param.second);
} }
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -69,6 +67,11 @@ void ReferenceCustomAngleIxn::setPeriodic(OpenMM::Vec3* vectors) { ...@@ -69,6 +67,11 @@ void ReferenceCustomAngleIxn::setPeriodic(OpenMM::Vec3* vectors) {
boxVectors[2] = vectors[2]; boxVectors[2] = vectors[2];
} }
void ReferenceCustomAngleIxn::setGlobalParameters(std::map<std::string, double> parameters) {
for (auto& param : parameters)
expressionSet.setVariable(expressionSet.getVariableIndex(param.first), param.second);
}
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom Angle Ixn Calculate Custom Angle Ixn
......
/* Portions copyright (c) 2009-2016 Stanford University and Simbios. /* Portions copyright (c) 2009-2018 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -39,7 +39,7 @@ using namespace OpenMM; ...@@ -39,7 +39,7 @@ using namespace OpenMM;
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomBondIxn::ReferenceCustomBondIxn(const Lepton::CompiledExpression& energyExpression, ReferenceCustomBondIxn::ReferenceCustomBondIxn(const Lepton::CompiledExpression& energyExpression,
const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames, map<string, double> globalParameters, const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames,
const vector<Lepton::CompiledExpression> energyParamDerivExpressions) : const vector<Lepton::CompiledExpression> energyParamDerivExpressions) :
energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) { energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) {
expressionSet.registerExpression(this->energyExpression); expressionSet.registerExpression(this->energyExpression);
...@@ -50,8 +50,6 @@ ReferenceCustomBondIxn::ReferenceCustomBondIxn(const Lepton::CompiledExpression& ...@@ -50,8 +50,6 @@ ReferenceCustomBondIxn::ReferenceCustomBondIxn(const Lepton::CompiledExpression&
numParameters = parameterNames.size(); numParameters = parameterNames.size();
for (auto& param : parameterNames) for (auto& param : parameterNames)
bondParamIndex.push_back(expressionSet.getVariableIndex(param)); bondParamIndex.push_back(expressionSet.getVariableIndex(param));
for (auto& param : globalParameters)
expressionSet.setVariable(expressionSet.getVariableIndex(param.first), param.second);
} }
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -70,6 +68,11 @@ void ReferenceCustomBondIxn::setPeriodic(OpenMM::Vec3* vectors) { ...@@ -70,6 +68,11 @@ void ReferenceCustomBondIxn::setPeriodic(OpenMM::Vec3* vectors) {
boxVectors[2] = vectors[2]; boxVectors[2] = vectors[2];
} }
void ReferenceCustomBondIxn::setGlobalParameters(std::map<std::string, double> parameters) {
for (auto& param : parameters)
expressionSet.setVariable(expressionSet.getVariableIndex(param.first), param.second);
}
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom Bond Ixn Calculate Custom Bond Ixn
......
...@@ -40,7 +40,7 @@ using namespace OpenMM; ...@@ -40,7 +40,7 @@ using namespace OpenMM;
ReferenceCustomExternalIxn::ReferenceCustomExternalIxn(const Lepton::CompiledExpression& energyExpression, ReferenceCustomExternalIxn::ReferenceCustomExternalIxn(const Lepton::CompiledExpression& energyExpression,
const Lepton::CompiledExpression& forceExpressionX, const Lepton::CompiledExpression& forceExpressionY, const Lepton::CompiledExpression& forceExpressionX, const Lepton::CompiledExpression& forceExpressionY,
const Lepton::CompiledExpression& forceExpressionZ, const vector<string>& parameterNames, map<string, double> globalParameters) : const Lepton::CompiledExpression& forceExpressionZ, const vector<string>& parameterNames) :
energyExpression(energyExpression), forceExpressionX(forceExpressionX), forceExpressionY(forceExpressionY), energyExpression(energyExpression), forceExpressionX(forceExpressionX), forceExpressionY(forceExpressionY),
forceExpressionZ(forceExpressionZ) { forceExpressionZ(forceExpressionZ) {
...@@ -63,12 +63,6 @@ ReferenceCustomExternalIxn::ReferenceCustomExternalIxn(const Lepton::CompiledExp ...@@ -63,12 +63,6 @@ ReferenceCustomExternalIxn::ReferenceCustomExternalIxn(const Lepton::CompiledExp
forceYParams.push_back(ReferenceForce::getVariablePointer(this->forceExpressionY, param)); forceYParams.push_back(ReferenceForce::getVariablePointer(this->forceExpressionY, param));
forceZParams.push_back(ReferenceForce::getVariablePointer(this->forceExpressionZ, param)); forceZParams.push_back(ReferenceForce::getVariablePointer(this->forceExpressionZ, param));
} }
for (auto& param : globalParameters) {
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->energyExpression, param.first), param.second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpressionX, param.first), param.second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpressionY, param.first), param.second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpressionZ, param.first), param.second);
}
} }
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -80,6 +74,15 @@ ReferenceCustomExternalIxn::ReferenceCustomExternalIxn(const Lepton::CompiledExp ...@@ -80,6 +74,15 @@ ReferenceCustomExternalIxn::ReferenceCustomExternalIxn(const Lepton::CompiledExp
ReferenceCustomExternalIxn::~ReferenceCustomExternalIxn() { ReferenceCustomExternalIxn::~ReferenceCustomExternalIxn() {
} }
void ReferenceCustomExternalIxn::setGlobalParameters(std::map<std::string, double> parameters) {
for (auto& param : parameters) {
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->energyExpression, param.first), param.second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpressionX, param.first), param.second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpressionY, param.first), param.second);
ReferenceForce::setVariable(ReferenceForce::getVariablePointer(this->forceExpressionZ, param.first), param.second);
}
}
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom External Ixn Calculate Custom External Ixn
......
/* Portions copyright (c) 2010-2016 Stanford University and Simbios. /* Portions copyright (c) 2010-2018 Stanford University and Simbios.
* Contributors: Peter Eastman * Contributors: Peter Eastman
* *
* Permission is hereby granted, free of charge, to any person obtaining * Permission is hereby granted, free of charge, to any person obtaining
...@@ -38,7 +38,7 @@ using namespace OpenMM; ...@@ -38,7 +38,7 @@ using namespace OpenMM;
--------------------------------------------------------------------------------------- */ --------------------------------------------------------------------------------------- */
ReferenceCustomTorsionIxn::ReferenceCustomTorsionIxn(const Lepton::CompiledExpression& energyExpression, ReferenceCustomTorsionIxn::ReferenceCustomTorsionIxn(const Lepton::CompiledExpression& energyExpression,
const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames, map<string, double> globalParameters, const Lepton::CompiledExpression& forceExpression, const vector<string>& parameterNames,
const vector<Lepton::CompiledExpression> energyParamDerivExpressions) : const vector<Lepton::CompiledExpression> energyParamDerivExpressions) :
energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) { energyExpression(energyExpression), forceExpression(forceExpression), usePeriodic(false), energyParamDerivExpressions(energyParamDerivExpressions) {
expressionSet.registerExpression(this->energyExpression); expressionSet.registerExpression(this->energyExpression);
...@@ -49,8 +49,6 @@ ReferenceCustomTorsionIxn::ReferenceCustomTorsionIxn(const Lepton::CompiledExpre ...@@ -49,8 +49,6 @@ ReferenceCustomTorsionIxn::ReferenceCustomTorsionIxn(const Lepton::CompiledExpre
numParameters = parameterNames.size(); numParameters = parameterNames.size();
for (auto& param : parameterNames) for (auto& param : parameterNames)
torsionParamIndex.push_back(expressionSet.getVariableIndex(param)); torsionParamIndex.push_back(expressionSet.getVariableIndex(param));
for (auto& param : globalParameters)
expressionSet.setVariable(expressionSet.getVariableIndex(param.first), param.second);
} }
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
...@@ -69,6 +67,11 @@ void ReferenceCustomTorsionIxn::setPeriodic(OpenMM::Vec3* vectors) { ...@@ -69,6 +67,11 @@ void ReferenceCustomTorsionIxn::setPeriodic(OpenMM::Vec3* vectors) {
boxVectors[2] = vectors[2]; boxVectors[2] = vectors[2];
} }
void ReferenceCustomTorsionIxn::setGlobalParameters(std::map<std::string, double> parameters) {
for (auto& param : parameters)
expressionSet.setVariable(expressionSet.getVariableIndex(param.first), param.second);
}
/**--------------------------------------------------------------------------------------- /**---------------------------------------------------------------------------------------
Calculate Custom Torsion Ixn Calculate Custom Torsion Ixn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment