"...ssh:/git@developer.sourcefind.cn:2222/tsoc/openmm.git" did not exist on "cc03f111d4f4535fe12d2fe8f0fe77e2ee4cc8e7"
Commit eaef52d9 authored by peastman's avatar peastman
Browse files

Created OpenCL implementation of periodicdistance()

parent 91a8cc49
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2009-2014 Stanford University and the Authors. * * Portions copyright (c) 2009-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -89,6 +89,10 @@ public: ...@@ -89,6 +89,10 @@ public:
* @param function the function for which to get a placeholder * @param function the function for which to get a placeholder
*/ */
Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function); Lepton::CustomFunction* getFunctionPlaceholder(const TabulatedFunction& function);
/**
* Get a Lepton::CustomFunction that can be used to represent the periodicdistance() function when parsing expressions.
*/
Lepton::CustomFunction* getPeriodicDistancePlaceholder();
private: private:
class FunctionPlaceholder : public Lepton::CustomFunction { class FunctionPlaceholder : public Lepton::CustomFunction {
public: public:
...@@ -114,13 +118,13 @@ private: ...@@ -114,13 +118,13 @@ private:
const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames, const std::vector<const TabulatedFunction*>& functions, const std::vector<std::pair<std::string, std::string> >& functionNames,
const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType); const std::string& prefix, const std::vector<std::vector<double> >& functionParams, const std::vector<Lepton::ParsedExpression>& allExpressions, const std::string& tempType);
std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps); std::string getTempName(const Lepton::ExpressionTreeNode& node, const std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& temps);
void findRelatedTabulatedFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedCustomFunctions(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::vector<const Lepton::ExpressionTreeNode*>& nodes); std::vector<const Lepton::ExpressionTreeNode*>& nodes);
void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode, void findRelatedPowers(const Lepton::ExpressionTreeNode& node, const Lepton::ExpressionTreeNode& searchNode,
std::map<int, const Lepton::ExpressionTreeNode*>& powers); std::map<int, const Lepton::ExpressionTreeNode*>& powers);
std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions); std::vector<std::vector<double> > computeFunctionParameters(const std::vector<const TabulatedFunction*>& functions);
OpenCLContext& context; OpenCLContext& context;
FunctionPlaceholder fp1, fp2, fp3; FunctionPlaceholder fp1, fp2, fp3, periodicDistance;
}; };
} // namespace OpenMM } // namespace OpenMM
......
...@@ -181,7 +181,7 @@ void OpenCLBondedUtilities::initialize(const System& system) { ...@@ -181,7 +181,7 @@ void OpenCLBondedUtilities::initialize(const System& system) {
for (int i = 0; i < (int) prefixCode.size(); i++) for (int i = 0; i < (int) prefixCode.size(); i++)
s<<prefixCode[i]; s<<prefixCode[i];
string bufferType = (context.getSupports64BitGlobalAtomics() ? "long" : "real4"); string bufferType = (context.getSupports64BitGlobalAtomics() ? "long" : "real4");
s<<"__kernel void computeBondedForces(__global "<<bufferType<<"* restrict forceBuffers, __global real* restrict energyBuffer, __global const real4* restrict posq, int groups"; s<<"__kernel void computeBondedForces(__global "<<bufferType<<"* restrict forceBuffers, __global real* restrict energyBuffer, __global const real4* restrict posq, int groups, real4 periodicBoxSize, real4 invPeriodicBoxSize, real4 periodicBoxVecX, real4 periodicBoxVecY, real4 periodicBoxVecZ";
for (int i = 0; i < setSize; i++) { for (int i = 0; i < setSize; i++) {
int force = set[i]; int force = set[i];
string indexType = "uint"+(indexWidth[force] == 1 ? "" : context.intToString(indexWidth[force])); string indexType = "uint"+(indexWidth[force] == 1 ? "" : context.intToString(indexWidth[force]));
...@@ -267,7 +267,7 @@ void OpenCLBondedUtilities::computeInteractions(int groups) { ...@@ -267,7 +267,7 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, context.getForceBuffers().getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, context.getEnergyBuffer().getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, context.getPosq().getDeviceBuffer());
index++; index += 6;
for (int j = 0; j < (int) forceSets[i].size(); j++) { for (int j = 0; j < (int) forceSets[i].size(); j++) {
kernel.setArg<cl::Buffer>(index++, atomIndices[forceSets[i][j]]->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, atomIndices[forceSets[i][j]]->getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, bufferIndices[forceSets[i][j]]->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, bufferIndices[forceSets[i][j]]->getDeviceBuffer());
...@@ -277,7 +277,22 @@ void OpenCLBondedUtilities::computeInteractions(int groups) { ...@@ -277,7 +277,22 @@ void OpenCLBondedUtilities::computeInteractions(int groups) {
} }
} }
for (int i = 0; i < (int) kernels.size(); i++) { for (int i = 0; i < (int) kernels.size(); i++) {
kernels[i].setArg<cl_int>(3, groups); cl::Kernel& kernel = kernels[i];
kernel.setArg<cl_int>(3, groups);
if (context.getUseDoublePrecision()) {
kernel.setArg<mm_double4>(4, context.getPeriodicBoxSizeDouble());
kernel.setArg<mm_double4>(5, context.getInvPeriodicBoxSizeDouble());
kernel.setArg<mm_double4>(6, context.getPeriodicBoxVecXDouble());
kernel.setArg<mm_double4>(7, context.getPeriodicBoxVecYDouble());
kernel.setArg<mm_double4>(8, context.getPeriodicBoxVecZDouble());
}
else {
kernel.setArg<mm_float4>(4, context.getPeriodicBoxSize());
kernel.setArg<mm_float4>(5, context.getInvPeriodicBoxSize());
kernel.setArg<mm_float4>(6, context.getPeriodicBoxVecX());
kernel.setArg<mm_float4>(7, context.getPeriodicBoxVecY());
kernel.setArg<mm_float4>(8, context.getPeriodicBoxVecZ());
}
context.executeKernel(kernels[i], maxBonds); context.executeKernel(kernels[i], maxBonds);
} }
} }
...@@ -3821,7 +3821,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const ...@@ -3821,7 +3821,9 @@ void OpenCLCalcCustomExternalForceKernel::initialize(const System& system, const
globalParamNames[i] = force.getGlobalParameterName(i); globalParamNames[i] = force.getGlobalParameterName(i);
globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i); globalParamValues[i] = (cl_float) force.getGlobalParameterDefaultValue(i);
} }
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction()).optimize(); map<string, Lepton::CustomFunction*> customFunctions;
customFunctions["periodicdistance"] = cl.getExpressionUtilities().getPeriodicDistancePlaceholder();
Lepton::ParsedExpression energyExpression = Lepton::Parser::parse(force.getEnergyFunction(), customFunctions).optimize();
Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize(); Lepton::ParsedExpression forceExpressionX = energyExpression.differentiate("x").optimize();
Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize(); Lepton::ParsedExpression forceExpressionY = energyExpression.differentiate("y").optimize();
Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize(); Lepton::ParsedExpression forceExpressionZ = energyExpression.differentiate("z").optimize();
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
* Biological Structures at Stanford, funded under the NIH Roadmap for * * Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. * * Medical Research, grant U54 GM072970. See https://simtk.org. *
* * * *
* Portions copyright (c) 2008-2009 Stanford University and the Authors. * * Portions copyright (c) 2008-2015 Stanford University and the Authors. *
* Authors: Peter Eastman * * Authors: Peter Eastman *
* Contributors: * * Contributors: *
* * * *
...@@ -161,6 +161,47 @@ void testParallelComputation() { ...@@ -161,6 +161,47 @@ void testParallelComputation() {
ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5); ASSERT_EQUAL_VEC(state1.getForces()[i], state2.getForces()[i], 1e-5);
} }
void testPeriodic() {
Vec3 vx(5, 0, 0);
Vec3 vy(0, 6, 0);
Vec3 vz(1, 2, 7);
double x0 = 51, y0 = -17, z0 = 11.2;
System system;
system.setDefaultPeriodicBoxVectors(vx, vy, vz);
system.addParticle(1.0);
CustomExternalForce* force = new CustomExternalForce("periodicdistance(x, y, z, x0, y0, z0)^2");
force->addPerParticleParameter("x0");
force->addPerParticleParameter("y0");
force->addPerParticleParameter("z0");
vector<double> params(3);
params[0] = x0;
params[1] = y0;
params[2] = z0;
force->addParticle(0, params);
system.addForce(force);
VerletIntegrator integrator(0.01);
Context context(system, integrator, platform);
vector<Vec3> positions(1);
positions[0] = Vec3(0, 2, 0);
context.setPositions(positions);
for (int i = 0; i < 100; i++) {
State state = context.getState(State::Positions | State::Forces | State::Energy);
// Apply periodic boundary conditions to the difference between the two positions.
Vec3 delta = Vec3(x0, y0, z0)-state.getPositions()[0];
delta -= vz*floor(delta[2]/vz[2]+0.5);
delta -= vy*floor(delta[1]/vy[1]+0.5);
delta -= vx*floor(delta[0]/vx[0]+0.5);
// Verify that the force and energy are correct.
ASSERT_EQUAL_VEC(delta*2, state.getForces()[0], 1e-5);
ASSERT_EQUAL_TOL(delta.dot(delta), state.getPotentialEnergy(), 1e-5);
integrator.step(1);
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
try { try {
if (argc > 1) if (argc > 1)
...@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) { ...@@ -168,6 +209,7 @@ int main(int argc, char* argv[]) {
testForce(); testForce();
testManyParameters(); testManyParameters();
testParallelComputation(); testParallelComputation();
testPeriodic();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment