Commit 65b56ee1 authored by peastman's avatar peastman
Browse files

Implemented type filters for CUDA version of CustomManyParticleForce

parent 785592f4
...@@ -931,7 +931,7 @@ private: ...@@ -931,7 +931,7 @@ private:
class CudaCalcCustomManyParticleForceKernel : public CalcCustomManyParticleForceKernel { class CudaCalcCustomManyParticleForceKernel : public CalcCustomManyParticleForceKernel {
public: public:
CudaCalcCustomManyParticleForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomManyParticleForceKernel(name, platform), CudaCalcCustomManyParticleForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomManyParticleForceKernel(name, platform),
hasInitializedKernel(false), cu(cu), params(NULL), globals(NULL), particleTypes(NULL), system(system) { hasInitializedKernel(false), cu(cu), params(NULL), globals(NULL), particleTypes(NULL), orderIndex(NULL), particleOrder(NULL), system(system) {
} }
~CudaCalcCustomManyParticleForceKernel(); ~CudaCalcCustomManyParticleForceKernel();
/** /**
...@@ -965,6 +965,8 @@ private: ...@@ -965,6 +965,8 @@ private:
CudaParameterSet* params; CudaParameterSet* params;
CudaArray* globals; CudaArray* globals;
CudaArray* particleTypes; CudaArray* particleTypes;
CudaArray* orderIndex;
CudaArray* particleOrder;
std::vector<std::string> globalParamNames; std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues; std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions; std::vector<CudaArray*> tabulatedFunctions;
......
...@@ -4430,6 +4430,10 @@ CudaCalcCustomManyParticleForceKernel::~CudaCalcCustomManyParticleForceKernel() ...@@ -4430,6 +4430,10 @@ CudaCalcCustomManyParticleForceKernel::~CudaCalcCustomManyParticleForceKernel()
delete params; delete params;
if (globals != NULL) if (globals != NULL)
delete globals; delete globals;
if (orderIndex != NULL)
delete orderIndex;
if (particleOrder != NULL)
delete particleOrder;
if (particleTypes != NULL) if (particleTypes != NULL)
delete particleTypes; delete particleTypes;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++) for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
...@@ -4512,6 +4516,27 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4512,6 +4516,27 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
variables.push_back(makeVariable(name, value)); variables.push_back(makeVariable(name, value));
} }
} }
// Build data structures for type filters.
vector<int> particleTypesVec;
vector<int> orderIndexVec;
vector<std::vector<int> > particleOrderVec;
int numTypes;
CustomManyParticleForceImpl::buildFilterArrays(force, numTypes, particleTypesVec, orderIndexVec, particleOrderVec);
bool hasTypeFilters = (particleOrderVec.size() > 1);
if (hasTypeFilters) {
particleTypes = CudaArray::create<int>(cu, particleTypesVec.size(), "customManyParticleTypes");
orderIndex = CudaArray::create<int>(cu, orderIndexVec.size(), "customManyParticleOrderIndex");
particleOrder = CudaArray::create<int>(cu, particleOrderVec.size()*particlesPerSet, "customManyParticleOrder");
particleTypes->upload(particleTypesVec);
orderIndex->upload(orderIndexVec);
vector<int> flattenedOrder(particleOrder->getSize());
for (int i = 0; i < (int) particleOrderVec.size(); i++)
for (int j = 0; j < particlesPerSet; j++)
flattenedOrder[i*particlesPerSet+j] = particleOrderVec[i][j];
particleOrder->upload(flattenedOrder);
}
// Now to generate the kernel. First, it needs to calculate all distances, angles, // Now to generate the kernel. First, it needs to calculate all distances, angles,
// and dihedrals the expression depends on. // and dihedrals the expression depends on.
...@@ -4677,8 +4702,20 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4677,8 +4702,20 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
// Create other replacements that depend on the number of particles per set. // Create other replacements that depend on the number of particles per set.
stringstream numCombinations, atomsForCombination, isValidCombination, permute, loadData, verifyCutoff; stringstream numCombinations, atomsForCombination, isValidCombination, permute, loadData, verifyCutoff;
if (hasTypeFilters) {
permute<<"int particleSet[] = {";
for (int i = 0; i < particlesPerSet; i++) {
permute<<"p"<<(i+1);
if (i < particlesPerSet-1)
permute<<", ";
}
permute<<"};\n";
}
for (int i = 0; i < particlesPerSet; i++) { for (int i = 0; i < particlesPerSet; i++) {
permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n"; if (hasTypeFilters)
permute<<"int atom"<<(i+1)<<" = particleSet[particleOrder["<<numTypes<<"*order+"<<i<<"]];\n";
else
permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n";
loadData<<"real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n"; loadData<<"real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n";
for (int j = 0; j < (int) params->getBuffers().size(); j++) for (int j = 0; j < (int) params->getBuffers().size(); j++)
loadData<<params->getBuffers()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n"; loadData<<params->getBuffers()[j].getType()<<" params"<<(j+1)<<(i+1)<<" = global_params"<<(j+1)<<"[atom"<<(i+1)<<"];\n";
...@@ -4704,6 +4741,9 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4704,6 +4741,9 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
for (int j = i+1; j < particlesPerSet; j++) for (int j = i+1; j < particlesPerSet; j++)
verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize).w < CUTOFF_SQUARED);\n"; verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize).w < CUTOFF_SQUARED);\n";
} }
string computeTypeIndex = "particleTypes[p"+cu.intToString(particlesPerSet)+"]";
for (int i = particlesPerSet-2; i >= 0; i--)
computeTypeIndex = "particleTypes[p"+cu.intToString(i+1)+"]+"+cu.intToString(numTypes)+"*("+computeTypeIndex+")";
// Create replacements for extra arguments. // Create replacements for extra arguments.
...@@ -4725,12 +4765,15 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con ...@@ -4725,12 +4765,15 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
replacements["VERIFY_CUTOFF"] = verifyCutoff.str(); replacements["VERIFY_CUTOFF"] = verifyCutoff.str();
replacements["PERMUTE_ATOMS"] = permute.str(); replacements["PERMUTE_ATOMS"] = permute.str();
replacements["LOAD_PARTICLE_DATA"] = loadData.str(); replacements["LOAD_PARTICLE_DATA"] = loadData.str();
replacements["COMPUTE_TYPE_INDEX"] = computeTypeIndex;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str(); replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
map<string, string> defines; map<string, string> defines;
if (nonbondedMethod != NoCutoff) if (nonbondedMethod != NoCutoff)
defines["USE_CUTOFF"] = "1"; defines["USE_CUTOFF"] = "1";
if (nonbondedMethod == CutoffPeriodic) if (nonbondedMethod == CutoffPeriodic)
defines["USE_PERIODIC"] = "1"; defines["USE_PERIODIC"] = "1";
if (hasTypeFilters)
defines["USE_FILTERS"] = "1";
defines["NUM_ATOMS"] = cu.intToString(cu.getNumAtoms()); defines["NUM_ATOMS"] = cu.intToString(cu.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms()); defines["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms());
defines["M_PI"] = cu.doubleToString(M_PI); defines["M_PI"] = cu.doubleToString(M_PI);
...@@ -4749,6 +4792,11 @@ double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool ...@@ -4749,6 +4792,11 @@ double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool
forceArgs.push_back(&cu.getPosq().getDevicePointer()); forceArgs.push_back(&cu.getPosq().getDevicePointer());
forceArgs.push_back(cu.getPeriodicBoxSizePointer()); forceArgs.push_back(cu.getPeriodicBoxSizePointer());
forceArgs.push_back(cu.getInvPeriodicBoxSizePointer()); forceArgs.push_back(cu.getInvPeriodicBoxSizePointer());
if (particleTypes != NULL) {
forceArgs.push_back(&particleTypes->getDevicePointer());
forceArgs.push_back(&orderIndex->getDevicePointer());
forceArgs.push_back(&particleOrder->getDevicePointer());
}
if (globals != NULL) if (globals != NULL)
forceArgs.push_back(&globals->getDevicePointer()); forceArgs.push_back(&globals->getDevicePointer());
for (int i = 0; i < (int) params->getBuffers().size(); i++) { for (int i = 0; i < (int) params->getBuffers().size(); i++) {
......
...@@ -64,6 +64,9 @@ inline __device__ real4 computeCross(real4 vec1, real4 vec2) { ...@@ -64,6 +64,9 @@ inline __device__ real4 computeCross(real4 vec1, real4 vec2) {
extern "C" __global__ void computeInteraction( extern "C" __global__ void computeInteraction(
unsigned long long* __restrict__ forceBuffers, real* __restrict__ energyBuffer, const real4* __restrict__ posq, unsigned long long* __restrict__ forceBuffers, real* __restrict__ energyBuffer, const real4* __restrict__ posq,
real4 periodicBoxSize, real4 invPeriodicBoxSize real4 periodicBoxSize, real4 invPeriodicBoxSize
#ifdef USE_FILTERS
, int* __restrict__ particleTypes, int* __restrict__ orderIndex, int* __restrict__ particleOrder
#endif
PARAMETER_ARGUMENTS) { PARAMETER_ARGUMENTS) {
real energy = 0.0f; real energy = 0.0f;
...@@ -79,6 +82,11 @@ extern "C" __global__ void computeInteraction( ...@@ -79,6 +82,11 @@ extern "C" __global__ void computeInteraction(
if (includeInteraction) { if (includeInteraction) {
VERIFY_CUTOFF; VERIFY_CUTOFF;
} }
#endif
#ifdef USE_FILTERS
int order = orderIndex[COMPUTE_TYPE_INDEX];
if (order == -1)
includeInteraction = false;
#endif #endif
if (includeInteraction) { if (includeInteraction) {
PERMUTE_ATOMS; PERMUTE_ATOMS;
......
...@@ -469,10 +469,10 @@ int main(int argc, char* argv[]) { ...@@ -469,10 +469,10 @@ int main(int argc, char* argv[]) {
testCutoff(); testCutoff();
testPeriodic(); testPeriodic();
// testExclusions(); // testExclusions();
// testAllTerms(); testAllTerms();
testParameters(); testParameters();
testTabulatedFunctions(); testTabulatedFunctions();
// testTypeFilters(); testTypeFilters();
} }
catch(const exception& e) { catch(const exception& e) {
cout << "exception: " << e.what() << endl; cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment