Commit 375a7081 authored by peastman's avatar peastman
Browse files

Merge pull request #1442 from peastman/tables

Fixed bug using tabulated functions with interaction groups
parents f2e879e0 52ed82ec
...@@ -718,7 +718,7 @@ public: ...@@ -718,7 +718,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force); void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private: private:
void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource); void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource, const std::vector<std::string>& tableTypes);
CudaContext& cu; CudaContext& cu;
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
......
...@@ -2157,6 +2157,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -2157,6 +2157,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < force.getNumTabulatedFunctions(); i++) { for (int i = 0; i < force.getNumTabulatedFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
...@@ -2168,6 +2169,10 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -2168,6 +2169,10 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(CudaArray::create<float>(cu, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer())); cu.getNonbondedUtilities().addArgument(CudaNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDevicePointer()));
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cu.intToString(width));
} }
// Record information for the expressions. // Record information for the expressions.
...@@ -2220,7 +2225,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -2220,7 +2225,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
} }
string source = cu.replaceStrings(CudaKernelSources::customNonbonded, replacements); string source = cu.replaceStrings(CudaKernelSources::customNonbonded, replacements);
if (force.getNumInteractionGroups() > 0) if (force.getNumInteractionGroups() > 0)
initInteractionGroups(force, source); initInteractionGroups(force, source, tableTypes);
else { else {
cu.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup()); cu.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
...@@ -2246,7 +2251,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const ...@@ -2246,7 +2251,7 @@ void CudaCalcCustomNonbondedForceKernel::initialize(const System& system, const
} }
} }
void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource) { void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource, const vector<string>& tableTypes) {
// Process groups to form tiles. // Process groups to form tiles.
vector<vector<int> > atomLists; vector<vector<int> > atomLists;
...@@ -2431,6 +2436,8 @@ void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbo ...@@ -2431,6 +2436,8 @@ void CudaCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbo
args<<", const "<<buffers[i].getType()<<"* __restrict__ global_params"<<(i+1); args<<", const "<<buffers[i].getType()<<"* __restrict__ global_params"<<(i+1);
if (globals != NULL) if (globals != NULL)
args<<", const float* __restrict__ globals"; args<<", const float* __restrict__ globals";
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
args << ", const " << tableTypes[i]<< "* __restrict__ table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream load1; stringstream load1;
for (int i = 0; i < (int) buffers.size(); i++) for (int i = 0; i < (int) buffers.size(); i++)
...@@ -2519,6 +2526,8 @@ double CudaCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool in ...@@ -2519,6 +2526,8 @@ double CudaCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool in
interactionGroupArgs.push_back(cu.getPeriodicBoxVecZPointer()); interactionGroupArgs.push_back(cu.getPeriodicBoxVecZPointer());
for (int i = 0; i < (int) params->getBuffers().size(); i++) for (int i = 0; i < (int) params->getBuffers().size(); i++)
interactionGroupArgs.push_back(&params->getBuffers()[i].getMemory()); interactionGroupArgs.push_back(&params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
interactionGroupArgs.push_back(&tabulatedFunctions[i]->getDevicePointer());
if (globals != NULL) if (globals != NULL)
interactionGroupArgs.push_back(&globals->getDevicePointer()); interactionGroupArgs.push_back(&globals->getDevicePointer());
} }
......
...@@ -698,7 +698,7 @@ public: ...@@ -698,7 +698,7 @@ public:
*/ */
void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force); void copyParametersToContext(ContextImpl& context, const CustomNonbondedForce& force);
private: private:
void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource); void initInteractionGroups(const CustomNonbondedForce& force, const std::string& interactionSource, const std::vector<std::string>& tableTypes);
OpenCLContext& cl; OpenCLContext& cl;
OpenCLParameterSet* params; OpenCLParameterSet* params;
OpenCLArray* globals; OpenCLArray* globals;
......
...@@ -2222,6 +2222,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2222,6 +2222,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
map<string, Lepton::CustomFunction*> functions; map<string, Lepton::CustomFunction*> functions;
vector<pair<string, string> > functionDefinitions; vector<pair<string, string> > functionDefinitions;
vector<const TabulatedFunction*> functionList; vector<const TabulatedFunction*> functionList;
vector<string> tableTypes;
for (int i = 0; i < force.getNumFunctions(); i++) { for (int i = 0; i < force.getNumFunctions(); i++) {
functionList.push_back(&force.getTabulatedFunction(i)); functionList.push_back(&force.getTabulatedFunction(i));
string name = force.getTabulatedFunctionName(i); string name = force.getTabulatedFunctionName(i);
...@@ -2233,6 +2234,10 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2233,6 +2234,10 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction")); tabulatedFunctions.push_back(OpenCLArray::create<float>(cl, f.size(), "TabulatedFunction"));
tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f); tabulatedFunctions[tabulatedFunctions.size()-1]->upload(f);
cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer())); cl.getNonbondedUtilities().addArgument(OpenCLNonbondedUtilities::ParameterInfo(arrayName, "float", width, width*sizeof(float), tabulatedFunctions[tabulatedFunctions.size()-1]->getDeviceBuffer()));
if (width == 1)
tableTypes.push_back("float");
else
tableTypes.push_back("float"+cl.intToString(width));
} }
// Record information for the expressions. // Record information for the expressions.
...@@ -2285,7 +2290,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2285,7 +2290,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
} }
string source = cl.replaceStrings(OpenCLKernelSources::customNonbonded, replacements); string source = cl.replaceStrings(OpenCLKernelSources::customNonbonded, replacements);
if (force.getNumInteractionGroups() > 0) if (force.getNumInteractionGroups() > 0)
initInteractionGroups(force, source); initInteractionGroups(force, source, tableTypes);
else { else {
cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup()); cl.getNonbondedUtilities().addInteraction(useCutoff, usePeriodic, true, force.getCutoffDistance(), exclusionList, source, force.getForceGroup());
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
...@@ -2311,7 +2316,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons ...@@ -2311,7 +2316,7 @@ void OpenCLCalcCustomNonbondedForceKernel::initialize(const System& system, cons
} }
} }
void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource) { void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNonbondedForce& force, const string& interactionSource, const vector<string>& tableTypes) {
// Process groups to form tiles. // Process groups to form tiles.
vector<vector<int> > atomLists; vector<vector<int> > atomLists;
...@@ -2504,6 +2509,8 @@ void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon ...@@ -2504,6 +2509,8 @@ void OpenCLCalcCustomNonbondedForceKernel::initInteractionGroups(const CustomNon
args<<", __global const "<<buffers[i].getType()<<"* restrict global_params"<<(i+1); args<<", __global const "<<buffers[i].getType()<<"* restrict global_params"<<(i+1);
if (globals != NULL) if (globals != NULL)
args<<", __global const float* restrict globals"; args<<", __global const float* restrict globals";
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
args << ", __global const " << tableTypes[i]<< "* restrict table" << i;
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream load1; stringstream load1;
for (int i = 0; i < (int) buffers.size(); i++) for (int i = 0; i < (int) buffers.size(); i++)
...@@ -2591,6 +2598,8 @@ double OpenCLCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool ...@@ -2591,6 +2598,8 @@ double OpenCLCalcCustomNonbondedForceKernel::execute(ContextImpl& context, bool
index += 5; index += 5;
for (int i = 0; i < (int) params->getBuffers().size(); i++) for (int i = 0; i < (int) params->getBuffers().size(); i++)
interactionGroupKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory()); interactionGroupKernel.setArg<cl::Memory>(index++, params->getBuffers()[i].getMemory());
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
interactionGroupKernel.setArg<cl::Memory>(index++, tabulatedFunctions[i]->getDeviceBuffer());
if (globals != NULL) if (globals != NULL)
interactionGroupKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer()); interactionGroupKernel.setArg<cl::Buffer>(index++, globals->getDeviceBuffer());
} }
......
...@@ -937,6 +937,37 @@ void testInteractionGroupLongRangeCorrection() { ...@@ -937,6 +937,37 @@ void testInteractionGroupLongRangeCorrection() {
ASSERT_EQUAL_TOL(expected, energy2-energy1, 1e-4); ASSERT_EQUAL_TOL(expected, energy2-energy1, 1e-4);
} }
void testInteractionGroupTabulatedFunction() {
System system;
system.addParticle(1.0);
system.addParticle(1.0);
VerletIntegrator integrator(0.01);
CustomNonbondedForce* forceField = new CustomNonbondedForce("fn(r-1)+1");
set<int> set1, set2;
set1.insert(0);
set2.insert(1);
forceField->addInteractionGroup(set1, set2);
forceField->addParticle(vector<double>());
forceField->addParticle(vector<double>());
vector<double> table;
for (int i = 0; i < 21; i++)
table.push_back(sin(0.25*i));
forceField->addTabulatedFunction("fn", new Discrete1DFunction(table));
system.addForce(forceField);
Context context(system, integrator, platform);
vector<Vec3> positions(2);
positions[0] = Vec3(0, 0, 0);
for (int i = 0; i < (int) table.size(); i++) {
positions[1] = Vec3(i+1, 0, 0);
context.setPositions(positions);
State state = context.getState(State::Forces | State::Energy);
const vector<Vec3>& forces = state.getForces();
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[0], 1e-6);
ASSERT_EQUAL_VEC(Vec3(0, 0, 0), forces[1], 1e-6);
ASSERT_EQUAL_TOL(table[i]+1.0, state.getPotentialEnergy(), 1e-6);
}
}
void testMultipleCutoffs() { void testMultipleCutoffs() {
System system; System system;
system.addParticle(1.0); system.addParticle(1.0);
...@@ -1033,6 +1064,7 @@ int main(int argc, char* argv[]) { ...@@ -1033,6 +1064,7 @@ int main(int argc, char* argv[]) {
testInteractionGroups(); testInteractionGroups();
testLargeInteractionGroup(); testLargeInteractionGroup();
testInteractionGroupLongRangeCorrection(); testInteractionGroupLongRangeCorrection();
testInteractionGroupTabulatedFunction();
testMultipleCutoffs(); testMultipleCutoffs();
testIllegalVariable(); testIllegalVariable();
runPlatformTests(); runPlatformTests();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment