Commit ff238051 authored by peastman's avatar peastman
Browse files

OpenCL implementation of vector functions for CustomIntegrator

parent ad285083
...@@ -1484,9 +1484,8 @@ class OpenCLIntegrateCustomStepKernel : public IntegrateCustomStepKernel { ...@@ -1484,9 +1484,8 @@ class OpenCLIntegrateCustomStepKernel : public IntegrateCustomStepKernel {
public: public:
enum GlobalTargetType {DT, VARIABLE, PARAMETER}; enum GlobalTargetType {DT, VARIABLE, PARAMETER};
OpenCLIntegrateCustomStepKernel(std::string name, const Platform& platform, OpenCLContext& cl) : IntegrateCustomStepKernel(name, platform), cl(cl), OpenCLIntegrateCustomStepKernel(std::string name, const Platform& platform, OpenCLContext& cl) : IntegrateCustomStepKernel(name, platform), cl(cl),
hasInitializedKernels(false), localValuesAreCurrent(false), perDofValues(NULL), needsEnergyParamDerivs(false) { hasInitializedKernels(false), needsEnergyParamDerivs(false) {
} }
~OpenCLIntegrateCustomStepKernel();
/** /**
* Initialize the kernel. * Initialize the kernel.
* *
...@@ -1550,7 +1549,7 @@ private: ...@@ -1550,7 +1549,7 @@ private:
class ReorderListener; class ReorderListener;
class GlobalTarget; class GlobalTarget;
class DerivFunction; class DerivFunction;
std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, CustomIntegrator& integrator,
const std::string& forceName, const std::string& energyName, std::vector<const TabulatedFunction*>& functions, const std::string& forceName, const std::string& energyName, std::vector<const TabulatedFunction*>& functions,
std::vector<std::pair<std::string, std::string> >& functionNames); std::vector<std::pair<std::string, std::string> >& functionNames);
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
...@@ -1563,21 +1562,21 @@ private: ...@@ -1563,21 +1562,21 @@ private:
double energy; double energy;
float energyFloat; float energyFloat;
int numGlobalVariables, sumWorkGroupSize; int numGlobalVariables, sumWorkGroupSize;
bool hasInitializedKernels, deviceValuesAreCurrent, deviceGlobalsAreCurrent, modifiesParameters, keNeedsForce, hasAnyConstraints, needsEnergyParamDerivs; bool hasInitializedKernels, deviceGlobalsAreCurrent, modifiesParameters, keNeedsForce, hasAnyConstraints, needsEnergyParamDerivs;
mutable bool localValuesAreCurrent; std::vector<bool> deviceValuesAreCurrent;
mutable std::vector<bool> localValuesAreCurrent;
OpenCLArray globalValues; OpenCLArray globalValues;
OpenCLArray sumBuffer; OpenCLArray sumBuffer;
OpenCLArray summedValue; OpenCLArray summedValue;
OpenCLArray uniformRandoms; OpenCLArray uniformRandoms;
OpenCLArray randomSeed; OpenCLArray randomSeed;
OpenCLArray perDofEnergyParamDerivs; OpenCLArray perDofEnergyParamDerivs;
std::vector<OpenCLArray> tabulatedFunctions; std::vector<OpenCLArray> tabulatedFunctions, perDofValues;
std::map<int, double> savedEnergy; std::map<int, double> savedEnergy;
std::map<int, OpenCLArray> savedForces; std::map<int, OpenCLArray> savedForces;
std::set<int> validSavedForces; std::set<int> validSavedForces;
OpenCLParameterSet* perDofValues; mutable std::vector<std::vector<mm_float4> > localPerDofValuesFloat;
mutable std::vector<std::vector<cl_float> > localPerDofValuesFloat; mutable std::vector<std::vector<mm_double4> > localPerDofValuesDouble;
mutable std::vector<std::vector<cl_double> > localPerDofValuesDouble;
std::map<std::string, double> energyParamDerivs; std::map<std::string, double> energyParamDerivs;
std::vector<std::string> perDofEnergyParamDerivNames; std::vector<std::string> perDofEnergyParamDerivNames;
std::vector<cl_float> localPerDofEnergyParamDerivsFloat; std::vector<cl_float> localPerDofEnergyParamDerivsFloat;
......
This diff is collapsed.
...@@ -102,7 +102,6 @@ void testSingleBond() { ...@@ -102,7 +102,6 @@ void testSingleBond() {
*/ */
void testConstraints() { void testConstraints() {
const int numParticles = 8; const int numParticles = 8;
const double temp = 500.0;
System system; System system;
CustomIntegrator integrator(0.002); CustomIntegrator integrator(0.002);
integrator.addPerDofVariable("oldx", 0); integrator.addPerDofVariable("oldx", 0);
...@@ -1028,8 +1027,10 @@ void testVectorFunctions() { ...@@ -1028,8 +1027,10 @@ void testVectorFunctions() {
CustomIntegrator integrator(0.001); CustomIntegrator integrator(0.001);
integrator.addGlobalVariable("sumy", 0.0); integrator.addGlobalVariable("sumy", 0.0);
integrator.addPerDofVariable("angular", 0.0); integrator.addPerDofVariable("angular", 0.0);
integrator.addPerDofVariable("shuffle", 0.0);
integrator.addComputeSum("sumy", "x*vector(0, 1, 0)"); integrator.addComputeSum("sumy", "x*vector(0, 1, 0)");
integrator.addComputePerDof("angular", "cross(v, x)"); integrator.addComputePerDof("angular", "cross(v, x)");
integrator.addComputePerDof("shuffle", "dot(vector(_z(x), _x(x), _y(x)), v)");
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
init_gen_rand(0, sfmt); init_gen_rand(0, sfmt);
vector<Vec3> positions(numParticles); vector<Vec3> positions(numParticles);
...@@ -1047,10 +1048,12 @@ void testVectorFunctions() { ...@@ -1047,10 +1048,12 @@ void testVectorFunctions() {
// See if the expressions were computed correctly. // See if the expressions were computed correctly.
double sumy = 0; double sumy = 0;
vector<Vec3> values; vector<Vec3> angular, shuffle;
integrator.getPerDofVariable(0, values); integrator.getPerDofVariable(0, angular);
integrator.getPerDofVariable(1, shuffle);
for (int i = 0; i < numParticles; i++) { for (int i = 0; i < numParticles; i++) {
ASSERT_EQUAL_VEC(velocities[i].cross(positions[i]), values[i], 1e-5); ASSERT_EQUAL_VEC(velocities[i].cross(positions[i]), angular[i], 1e-5);
ASSERT_EQUAL_VEC(Vec3(1, 1, 1)*velocities[i].dot(Vec3(positions[i][2], positions[i][0], positions[i][1])), shuffle[i], 1e-5);
sumy += positions[i][1]; sumy += positions[i][1];
} }
ASSERT_EQUAL_TOL(sumy, integrator.getGlobalVariable(0), 1e-5); ASSERT_EQUAL_TOL(sumy, integrator.getGlobalVariable(0), 1e-5);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment