Commit 6ba9a156 authored by Peter Eastman's avatar Peter Eastman
Browse files

OpenCL version of CustomHbondForce is now complete (except for exclusions)

parent 68c89df6
......@@ -35,6 +35,7 @@
#include "ForceImpl.h"
#include "openmm/CustomHbondForce.h"
#include "openmm/Kernel.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h"
#include "lepton/ParsedExpression.h"
#include <utility>
......@@ -68,15 +69,16 @@ public:
* as follows: 0=a1, 1=a2, 2=a3, 3=d1, 4=d2, 5=d3.
*
* @param force the CustomHbondForce to process
* @param distances on exist, this will contain an entry for each distance used in the expression. The key is the name
* @param functions definitions of custom function that may appear in the expression
* @param distances on exit, this will contain an entry for each distance used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param angles on exist, this will contain an entry for each angle used in the expression. The key is the name
* @param angles on exit, this will contain an entry for each angle used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @param dihedrals on exist, this will contain an entry for each dihedral used in the expression. The key is the name
* @param dihedrals on exit, this will contain an entry for each dihedral used in the expression. The key is the name
* of the corresponding variable, and the value is the list of particle indices.
* @return a Parsed expression for the energy
*/
static Lepton::ParsedExpression prepareExpression(const CustomHbondForce& force, std::map<std::string, std::vector<int> >& distances,
static Lepton::ParsedExpression prepareExpression(const CustomHbondForce& force, const std::map<std::string, Lepton::CustomFunction*>& functions, std::map<std::string, std::vector<int> >& distances,
std::map<std::string, std::vector<int> >& angles, std::map<std::string, std::vector<int> >& dihedrals);
private:
class FunctionPlaceholder;
......
......@@ -199,24 +199,16 @@ map<string, double> CustomHbondForceImpl::getDefaultParameters() {
return parameters;
}
ParsedExpression CustomHbondForceImpl::prepareExpression(const CustomHbondForce& force, map<string, vector<int> >& distances,
ParsedExpression CustomHbondForceImpl::prepareExpression(const CustomHbondForce& force, const map<string, CustomFunction*>& customFunctions, map<string, vector<int> >& distances,
map<string, vector<int> >& angles, map<string, vector<int> >& dihedrals) {
CustomHbondForceImpl::FunctionPlaceholder custom(1);
CustomHbondForceImpl::FunctionPlaceholder distance(2);
CustomHbondForceImpl::FunctionPlaceholder angle(3);
CustomHbondForceImpl::FunctionPlaceholder dihedral(4);
map<string, CustomFunction*> functions;
map<string, CustomFunction*> functions = customFunctions;
functions["distance"] = &distance;
functions["angle"] = &angle;
functions["dihedral"] = &dihedral;
for (int i = 0; i < force.getNumFunctions(); i++) {
string name;
vector<double> values;
double min, max;
bool interpolating;
force.getFunctionParameters(i, name, values, min, max, interpolating);
functions[name] = &custom;
}
ParsedExpression expression = Lepton::Parser::parse(force.getEnergyFunction(), functions);
map<string, int> atoms;
atoms["a1"] = 0;
......
This diff is collapsed.
......@@ -710,7 +710,7 @@ private:
std::vector<cl_float> globalParamValues;
std::vector<OpenCLArray<mm_float4>*> tabulatedFunctions;
System& system;
cl::Kernel kernel;
cl::Kernel donorKernel, acceptorKernel;
};
/**
......
......@@ -25,7 +25,6 @@ float4 deltaPeriodic(float4 vec1, float4 vec2) {
/**
* Compute the angle between two vectors. The w component of each vector should contain the squared magnitude.
*/
float computeAngle(float4 vec1, float4 vec2) {
float dot = vec1.x*vec2.x + vec1.y*vec2.y + vec1.z*vec2.z;
float cosine = dot/sqrt(vec1.w*vec2.w);
......@@ -45,120 +44,165 @@ float computeAngle(float4 vec1, float4 vec2) {
}
/**
* Compute hbond interactions.
* Compute the cross product of two vectors, setting the fourth component to the squared magnitude.
*/
float4 computeCross(float4 vec1, float4 vec2) {
float4 result = cross(vec1, vec2);
result.w = result.x*result.x + result.y*result.y + result.z*result.z;
return result;
}
__kernel void computeHbonds(__global float4* forceBuffers, __global float* energyBuffer, __global float4* posq, /*__global unsigned int* exclusions,
__global unsigned int* exclusionIndices, */__global int4* donorAtoms, __global int4* acceptorAtoms, __global int4* donorBufferIndices, __global int4* acceptorBufferIndices, __local float4* posBuffer, __local float4* deltaBuffer
/**
* Compute forces on donors.
*/
__kernel void computeDonorForces(__global float4* forceBuffers, __global float* energyBuffer, __global float4* posq, /*__global unsigned int* exclusions,
__global unsigned int* exclusionIndices, */__global int4* donorAtoms, __global int4* acceptorAtoms, __global int4* donorBufferIndices, __local float4* posBuffer
PARAMETER_ARGUMENTS) {
float energy = 0.0f;
unsigned int tgx = get_local_id(0) & (get_local_size(0)-1);
unsigned int tbx = get_local_id(0) - tgx;
float4 f1 = 0;
float4 f2 = 0;
for (int donorIndex = get_global_id(0); donorIndex < NUM_DONORS; donorIndex += get_global_size(0)) {
float4 f3 = 0;
for (int donorStart = 0; donorStart < NUM_DONORS; donorStart += get_global_size(0)) {
// Load information about the donor this thread will compute forces on.
int4 atoms = donorAtoms[donorIndex];
float4 d1 = posq[atoms.x];
float4 d2 = posq[atoms.y];
float4 d3 = posq[atoms.z];
float4 deltaD1D2 = delta(d1, d2);
int donorIndex = donorStart+get_global_id(0);
int4 atoms;
float4 d1, d2, d3;
if (donorIndex < NUM_DONORS) {
atoms = donorAtoms[donorIndex];
d1 = posq[atoms.x];
d2 = posq[atoms.y];
d3 = posq[atoms.z];
}
else
atoms = (int4) (-1, -1, -1, -1);
for (int acceptorStart = 0; acceptorStart < NUM_ACCEPTORS; acceptorStart += get_local_size(0)) {
// Load the next block of acceptors into local memory.
int blockSize = min((int) get_local_size(0), NUM_ACCEPTORS-acceptorStart);
if (tgx < blockSize) {
int4 atoms2 = acceptorAtoms[acceptorStart+tgx];
float4 pos1 = posq[atoms2.x];
float4 pos2 = posq[atoms2.y];
float4 pos3 = posq[atoms2.z];
posBuffer[get_local_id(0)] = pos1;
deltaBuffer[get_local_id(0)] = delta(pos2, pos1);
if (get_local_id(0) < blockSize) {
int4 atoms2 = acceptorAtoms[acceptorStart+get_local_id(0)];
posBuffer[3*get_local_id(0)] = posq[atoms2.x];
posBuffer[3*get_local_id(0)+1] = posq[atoms2.y];
posBuffer[3*get_local_id(0)+2] = posq[atoms2.z];
}
barrier(CLK_LOCAL_MEM_FENCE);
for (int index = 0; index < blockSize; index++) {
// Compute the interaction between a donor and an acceptor.
float4 a1 = posBuffer[index];
float4 deltaD1A1 = deltaPeriodic(d1, a1);
if (donorIndex < NUM_DONORS) {
for (int index = 0; index < blockSize; index++) {
// Compute the interaction between a donor and an acceptor.
float4 a1 = posBuffer[3*index];
float4 a2 = posBuffer[3*index+1];
float4 a3 = posBuffer[3*index+2];
float4 deltaD1A1 = deltaPeriodic(d1, a1);
#ifdef USE_CUTOFF
if (deltaD1A1.w < CUTOFF_SQUARED) {
#endif
// Compute variables the force can depend on.
float r = sqrt(deltaD1A1.w);
float4 deltaA2A1 = deltaBuffer[index];
float theta = computeAngle(deltaD1A1, deltaD1D2);
float psi = computeAngle(deltaD1A1, deltaA2A1);
float4 cross1 = cross(deltaA2A1, deltaD1A1);
float4 cross2 = cross(deltaD1A1, deltaD1D2);
cross1.w = cross1.x*cross1.x + cross1.y*cross1.y + cross1.z*cross1.z;
cross2.w = cross2.x*cross2.x + cross2.y*cross2.y + cross2.z*cross2.z;
float chi = computeAngle(cross1, cross2);
chi = (dot(deltaA2A1, cross2) < 0 ? -chi : chi);
COMPUTE_FORCE
#ifdef INCLUDE_R
// Apply forces based on r.
f1.xyz -= (dEdR/r)*deltaD1A1.xyz;
if (deltaD1A1.w < CUTOFF_SQUARED) {
#endif
#ifdef INCLUDE_THETA
// Apply forces based on theta.
float4 thetaCross = cross(deltaD1D2, deltaD1A1);
float lengthThetaCross = max(length(thetaCross), 1e-6f);
float4 deltaCross0 = cross(deltaD1D2, thetaCross)*dEdTheta/(deltaD1D2.w*lengthThetaCross);
float4 deltaCross2 = -cross(deltaD1A1, thetaCross)*dEdTheta/(deltaD1A1.w*lengthThetaCross);
float4 deltaCross1 = -(deltaCross0+deltaCross2);
f1.xyz += deltaCross1.xyz;
f2.xyz += deltaCross0.xyz;
#endif
#ifdef INCLUDE_PSI
// Apply forces based on psi.
float4 psiCross = cross(deltaA2A1, deltaD1A1);
float lengthPsiCross = max(length(psiCross), 1e-6f);
deltaCross0 = cross(deltaD1A1, psiCross)*dEdPsi/(deltaD1A1.w*lengthPsiCross);
// float4 deltaCross2 = -cross(deltaA2A1, psiCross)*dEdPsi/(deltaA2A1.w*lengthPsiCross);
// float4 deltaCross1 = -(deltaCross0+deltaCross2);
f1.xyz += deltaCross0.xyz;
COMPUTE_DONOR_FORCE
#ifdef USE_CUTOFF
}
#endif
}
}
}
#ifdef INCLUDE_CHI
// Apply forces based on chi.
// Write results
float4 ff;
ff.x = (-dEdChi*r)/cross1.w;
ff.y = (deltaA2A1.x*deltaD1A1.x + deltaA2A1.y*deltaD1A1.y + deltaA2A1.z*deltaD1A1.z)/deltaD1A1.w;
ff.z = (deltaD1D2.x*deltaD1A1.x + deltaD1D2.y*deltaD1A1.y + deltaD1D2.z*deltaD1A1.z)/deltaD1A1.w;
ff.w = (dEdChi*r)/cross2.w;
float4 internalF0 = ff.x*cross1;
float4 internalF3 = ff.w*cross2;
float4 s = ff.y*internalF0 - ff.z*internalF3;
f1.xyz -= s.xyz+internalF3.xyz;
f2.xyz += internalF3.xyz;
int4 bufferIndices = donorBufferIndices[donorIndex];
if (atoms.x > -1) {
unsigned int offset = atoms.x+bufferIndices.x*PADDED_NUM_ATOMS;
float4 force = forceBuffers[offset];
force.xyz += f1.xyz;
forceBuffers[offset] = force;
}
if (atoms.y > -1) {
unsigned int offset = atoms.y+bufferIndices.y*PADDED_NUM_ATOMS;
float4 force = forceBuffers[offset];
force.xyz += f2.xyz;
forceBuffers[offset] = force;
}
if (atoms.z > -1) {
unsigned int offset = atoms.z+bufferIndices.z*PADDED_NUM_ATOMS;
float4 force = forceBuffers[offset];
force.xyz += f3.xyz;
forceBuffers[offset] = force;
}
}
energyBuffer[get_global_id(0)] += energy;
}
/**
* Compute forces on acceptors.
*/
__kernel void computeAcceptorForces(__global float4* forceBuffers, __global float* energyBuffer, __global float4* posq, /*__global unsigned int* exclusions,
__global unsigned int* exclusionIndices, */__global int4* donorAtoms, __global int4* acceptorAtoms, __global int4* acceptorBufferIndices, __local float4* posBuffer
PARAMETER_ARGUMENTS) {
float4 f1 = 0;
float4 f2 = 0;
float4 f3 = 0;
for (int acceptorStart = 0; acceptorStart < NUM_ACCEPTORS; acceptorStart += get_global_size(0)) {
// Load information about the acceptor this thread will compute forces on.
int acceptorIndex = acceptorStart+get_global_id(0);
int4 atoms;
float4 a1, a2, a3;
if (acceptorIndex < NUM_ACCEPTORS) {
atoms = acceptorAtoms[acceptorIndex];
a1 = posq[atoms.x];
a2 = posq[atoms.y];
a3 = posq[atoms.z];
}
else
atoms = (int4) (-1, -1, -1, -1);
for (int donorStart = 0; donorStart < NUM_DONORS; donorStart += get_local_size(0)) {
// Load the next block of donors into local memory.
int blockSize = min((int) get_local_size(0), NUM_DONORS-donorStart);
if (get_local_id(0) < blockSize) {
int4 atoms2 = donorAtoms[donorStart+get_local_id(0)];
posBuffer[3*get_local_id(0)] = posq[atoms2.x];
posBuffer[3*get_local_id(0)+1] = posq[atoms2.y];
posBuffer[3*get_local_id(0)+2] = posq[atoms2.z];
}
barrier(CLK_LOCAL_MEM_FENCE);
if (acceptorIndex < NUM_ACCEPTORS) {
for (int index = 0; index < blockSize; index++) {
// Compute the interaction between a donor and an acceptor.
float4 d1 = posBuffer[3*index];
float4 d2 = posBuffer[3*index+1];
float4 d3 = posBuffer[3*index+2];
float4 deltaD1A1 = deltaPeriodic(d1, a1);
#ifdef USE_CUTOFF
if (deltaD1A1.w < CUTOFF_SQUARED) {
#endif
COMPUTE_ACCEPTOR_FORCE
#ifdef USE_CUTOFF
}
}
#endif
}
}
}
// Write results
int4 bufferIndices = donorBufferIndices[donorIndex];
unsigned int offset1 = atoms.x+bufferIndices.x*PADDED_NUM_ATOMS;
unsigned int offset2 = atoms.y+bufferIndices.y*PADDED_NUM_ATOMS;
float4 force1 = forceBuffers[offset1];
float4 force2 = forceBuffers[offset2];
force1.xyz += f1.xyz;
force2.xyz += f2.xyz;
forceBuffers[offset1] = force1;
forceBuffers[offset2] = force2;
int4 bufferIndices = acceptorBufferIndices[acceptorIndex];
if (atoms.x > -1) {
unsigned int offset = atoms.x+bufferIndices.x*PADDED_NUM_ATOMS;
float4 force = forceBuffers[offset];
force.xyz += f1.xyz;
forceBuffers[offset] = force;
}
if (atoms.y > -1) {
unsigned int offset = atoms.y+bufferIndices.y*PADDED_NUM_ATOMS;
float4 force = forceBuffers[offset];
force.xyz += f2.xyz;
forceBuffers[offset] = force;
}
if (atoms.z > -1) {
unsigned int offset = atoms.z+bufferIndices.z*PADDED_NUM_ATOMS;
float4 force = forceBuffers[offset];
force.xyz += f3.xyz;
forceBuffers[offset] = force;
}
}
energyBuffer[get_global_id(0)] += energy;
}
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2008-2010 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
/**
* This tests the OpenCL implementation of CustomHbondForce.
*/
#include "../../../tests/AssertionUtilities.h"
#include "openmm/Context.h"
#include "OpenCLPlatform.h"
#include "openmm/CustomHbondForce.h"
#include "openmm/HarmonicAngleForce.h"
#include "openmm/HarmonicBondForce.h"
#include "openmm/PeriodicTorsionForce.h"
#include "openmm/System.h"
#include "openmm/VerletIntegrator.h"
#include "../src/sfmt/SFMT.h"
#include <iostream>
#include <vector>
using namespace OpenMM;
using namespace std;
const double TOL = 1e-5;
void testHbond() {
OpenCLPlatform platform;
// Create a system using a CustomHbondForce.
System customSystem;
customSystem.addParticle(1.0);
customSystem.addParticle(1.0);
customSystem.addParticle(1.0);
customSystem.addParticle(1.0);
customSystem.addParticle(1.0);
CustomHbondForce* custom = new CustomHbondForce("0.5*kr*(distance(d1,a1)-r0)^2 + 0.5*ktheta*(angle(a1,d1,d2)-theta0)^2 + 0.5*kpsi*(angle(d1,a1,a2)-psi0)^2 + kchi*(1+cos(n*dihedral(a3,a2,a1,d1)-chi0))");
custom->addPerDonorParameter("r0");
custom->addPerDonorParameter("theta0");
custom->addPerDonorParameter("psi0");
custom->addPerAcceptorParameter("chi0");
custom->addPerAcceptorParameter("n");
custom->addGlobalParameter("kr", 0.4);
custom->addGlobalParameter("ktheta", 0.5);
custom->addGlobalParameter("kpsi", 0.6);
custom->addGlobalParameter("kchi", 0.7);
vector<double> parameters(3);
parameters[0] = 1.5;
parameters[1] = 1.7;
parameters[2] = 1.9;
custom->addDonor(1, 0, -1, parameters);
parameters.resize(2);
parameters[0] = 2.1;
parameters[1] = 2;
custom->addAcceptor(2, 3, 4, parameters);
custom->setCutoffDistance(10.0);
customSystem.addForce(custom);
// Create an identical system using HarmonicBondForce, HarmonicAngleForce, and PeriodicTorsionForce.
System standardSystem;
standardSystem.addParticle(1.0);
standardSystem.addParticle(1.0);
standardSystem.addParticle(1.0);
standardSystem.addParticle(1.0);
standardSystem.addParticle(1.0);
HarmonicBondForce* bond = new HarmonicBondForce();
bond->addBond(1, 2, 1.5, 0.4);
standardSystem.addForce(bond);
HarmonicAngleForce* angle = new HarmonicAngleForce();
angle->addAngle(0, 1, 2, 1.7, 0.5);
angle->addAngle(1, 2, 3, 1.9, 0.6);
standardSystem.addForce(angle);
PeriodicTorsionForce* torsion = new PeriodicTorsionForce();
torsion->addTorsion(1, 2, 3, 4, 2, 2.1, 0.7);;
standardSystem.addForce(torsion);
// Set the atoms in various positions, and verify that both systems give identical forces and energy.
init_gen_rand(0);
vector<Vec3> positions(5);
VerletIntegrator integrator1(0.01);
VerletIntegrator integrator2(0.01);
for (int i = 0; i < 10; i++) {
Context c1(customSystem, integrator1, platform);
Context c2(standardSystem, integrator2, platform);
for (int j = 0; j < (int) positions.size(); j++)
positions[j] = Vec3(2.0*genrand_real2(), 2.0*genrand_real2(), 2.0*genrand_real2());
c1.setPositions(positions);
c2.setPositions(positions);
State s1 = c1.getState(State::Forces | State::Energy);
State s2 = c2.getState(State::Forces | State::Energy);
for (int i = 0; i < customSystem.getNumParticles(); i++)
ASSERT_EQUAL_VEC(s2.getForces()[i], s1.getForces()[i], TOL);
ASSERT_EQUAL_TOL(s2.getPotentialEnergy(), s1.getPotentialEnergy(), TOL);
}
}
void testExclusions() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomHbondForce* custom = new CustomHbondForce("(distance(d1,a1)-1)^2");
custom->addDonor(0, 1, -1, vector<double>());
custom->addDonor(1, 0, -1, vector<double>());
custom->addAcceptor(2, 0, -1, vector<double>());
custom->addExclusion(1, 0);
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, 2, 0);
positions[2] = Vec3(2, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(2, 0, 0), forces[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-2, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(1.0, state.getPotentialEnergy(), TOL);
}
void testCutoff() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomHbondForce* custom = new CustomHbondForce("(distance(d1,a1)-1)^2");
custom->addDonor(0, 1, -1, vector<double>());
custom->addDonor(1, 0, -1, vector<double>());
custom->addAcceptor(2, 0, -1, vector<double>());
custom->setNonbondedMethod(CustomHbondForce::CutoffNonPeriodic);
custom->setCutoffDistance(2.5);
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, 3, 0);
positions[2] = Vec3(2, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(2, 0, 0), forces[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-2, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(1.0, state.getPotentialEnergy(), TOL);
}
void testCustomFunctions() {
OpenCLPlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomHbondForce* custom = new CustomHbondForce("foo(distance(d1,a1))");
custom->addDonor(1, 0, -1, vector<double>());
custom->addDonor(2, 0, -1, vector<double>());
custom->addAcceptor(0, 1, -1, vector<double>());
vector<double> function(2);
function[0] = 0;
function[1] = 1;
custom->addFunction("foo", function, 0, 10, true);
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, 2, 0);
positions[2] = Vec3(2, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0.1, 0.1, 0), forces[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL);
}
int main() {
try {
testHbond();
// testExclusions();
testCutoff();
testCustomFunctions();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
return 1;
}
cout << "Done" << endl;
return 0;
}
......@@ -1314,7 +1314,7 @@ void ReferenceCalcCustomHbondForceKernel::initialize(const System& system, const
map<string, vector<int> > distances;
map<string, vector<int> > angles;
map<string, vector<int> > dihedrals;
Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, distances, angles, dihedrals);
Lepton::ParsedExpression energyExpression = CustomHbondForceImpl::prepareExpression(force, functions, distances, angles, dihedrals);
vector<string> donorParameterNames;
vector<string> acceptorParameterNames;
for (int i = 0; i < numDonorParameters; i++)
......
......@@ -179,11 +179,42 @@ void testCutoff() {
ASSERT_EQUAL_TOL(1.0, state.getPotentialEnergy(), TOL);
}
void testCustomFunctions() {
ReferencePlatform platform;
System system;
system.addParticle(1.0);
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomHbondForce* custom = new CustomHbondForce("foo(distance(d1,a1))");
custom->addDonor(1, 0, -1, vector<double>());
custom->addDonor(2, 0, -1, vector<double>());
custom->addAcceptor(0, 1, -1, vector<double>());
vector<double> function(2);
function[0] = 0;
function[1] = 1;
custom->addFunction("foo", function, 0, 10, true);
system.addForce(custom);
Context context(system, integrator, platform);
vector<Vec3> positions(3);
positions[0] = Vec3(0, 0, 0);
positions[1] = Vec3(0, 2, 0);
positions[2] = Vec3(2, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0.1, 0.1, 0), forces[0], TOL);
ASSERT_EQUAL_VEC(Vec3(0, -0.1, 0), forces[1], TOL);
ASSERT_EQUAL_VEC(Vec3(-0.1, 0, 0), forces[2], TOL);
ASSERT_EQUAL_TOL(0.1*2+0.1*2, state.getPotentialEnergy(), TOL);
}
int main() {
try {
testHbond();
testExclusions();
testCutoff();
testCustomFunctions();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment