Commit 52ed82ec authored by Peter Eastman's avatar Peter Eastman
Browse files

Fixed bug using tabulated functions with interaction groups

parent 293d7c6b
......@@ -718,7 +718,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private:
void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource);
void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource, const std::vector<std::string>& tableTypes);
CudaContext& cu;
CudaParameterSet* params;
CudaArray* globals;
......
......@@ -2157,6 +2157,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
......@@ -2168,6 +2169,10 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cu.intToString(width));
}
// Record information for the expressions.
......@@ -2220,7 +2225,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
}
string source = cu.replaceStrings(CudaKernelSources::customNonbonded, replacements);
if (force.getNumInteractionGroups() > 0)
initInteractionGroups(force, source);
initInteractionGroups(force, source, tableTypes);
else {
cu.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
for (int i = 0; i < (int) params->getBuffers().size(); i++) {
......@@ -2246,7 +2251,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
}
}
void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource) {
void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource, const vector<string>& tableTypes) {
// Process groups to form tiles.
vector<vector<int> > atomLists;
......@@ -2431,6 +2436,8 @@ void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbo
args<<", const "<<buffers[i].getType()<<"* __restrict__ global_params"<<(i+1);
if (globals != NULL)
args<<", const float* __restrict__ globals";
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
args << ", const " << tableTypes[i]<< "* __restrict__ table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream load1;
for (int i = 0; i < (int) buffers.size(); i++)
......@@ -2519,6 +2526,8 @@ double CudaCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool in
interactionGroupArgs.push_back(cu.getPeriodicBoxVecZPointer());
for (int i = 0; i < (int) params->getBuffers().size(); i++)
interactionGroupArgs.push_back(&params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
interactionGroupArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
if (globals != NULL)
interactionGroupArgs.push_back(&globals->getDevicePointer());
}
......
......@@ -698,7 +698,7 @@ public:
*/
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private:
void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource);
void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource, const std::vector<std::string>& tableTypes);
OpenCLContext& cl;
OpenCLParameterSet* params;
OpenCLArray* globals;
......
......@@ -2222,6 +2222,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i);
......@@ -2233,6 +2234,10 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cl.intToString(width));
}
// Record information for the expressions.
......@@ -2285,7 +2290,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
}
string source = cl.replaceStrings(OpenCLKernelSources::customNonbonded, replacements);
if (force.getNumInteractionGroups() > 0)
initInteractionGroups(force, source);
initInteractionGroups(force, source, tableTypes);
else {
cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
for (int i = 0; i < (int) params->getBuffers().size(); i++) {
......@@ -2311,7 +2316,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
}
}
void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource) {
void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource, const vector<string>& tableTypes) {
// Process groups to form tiles.
vector<vector<int> > atomLists;
......@@ -2504,6 +2509,8 @@ void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon
args<<", __global const "<<buffers[i].getType()<<"* restrict global_params"<<(i+1);
if (globals != NULL)
args<<", __global const float* restrict globals";
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream load1;
for (int i = 0; i < (int) buffers.size(); i++)
......@@ -2591,6 +2598,8 @@ double OpenCLCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
index += 5;
for (int i = 0; i < (int) params->getBuffers().size(); i++)
interactionGroupKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
interactionGroupKernel.setArg<cl::Memory>(index++, tabulatedFunctions[i]->getDeviceBuffer());
if (globals != NULL)
interactionGroupKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
}
......
......@@ -937,6 +937,37 @@ void testInteractionGroupLongRangeCorrection() {
ASSERT_EQUAL_TOL(expected, energy2-energy1, 1e-4);
}
void testInteractionGroupTabulatedFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
set<int> set1, set2;
set1.insert(0);
set2.insert(1);
forceField->addInteractionGroup(set1, set2);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testMultipleCutoffs() {
System system;
system.addParticle(1.0);
......@@ -1033,6 +1064,7 @@ int main(int argc, char* argv[]) {
testInteractionGroups();
testLargeInteractionGroup();
testInteractionGroupLongRangeCorrection();
testInteractionGroupTabulatedFunction();
testMultipleCutoffs();
testIllegalVariable();
runPlatformTests();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment