"plugins/vscode:/vscode.git/clone" did not exist on "10b51d251a193936ce11893d5fd014dde3d1965d"
Commit 65b56ee1 authored by peastman's avatar peastman
Browse files

Implemented type filters for CUDA version of CustomManyParticleForce

parent 785592f4
......@@ -931,7 +931,7 @@ private:
class CudaCalcCustomManyParticleForceKernel : public CalcCustomManyParticleForceKernel {
public:
CudaCalcCustomManyParticleForceKernel(std::string name, const Platform& platform, CudaContext& cu, const System& system) : CalcCustomManyParticleForceKernel(name, platform),
hasInitializedKernel(false), cu(cu), params(NULL), globals(NULL), particleTypes(NULL), system(system) {
hasInitializedKernel(false), cu(cu), params(NULL), globals(NULL), particleTypes(NULL), orderIndex(NULL), particleOrder(NULL), system(system) {
}
~CudaCalcCustomManyParticleForceKernel();
/**
......@@ -965,6 +965,8 @@ private:
CudaParameterSet* params;
CudaArray* globals;
CudaArray* particleTypes;
CudaArray* orderIndex;
CudaArray* particleOrder;
std::vector<std::string> globalParamNames;
std::vector<float> globalParamValues;
std::vector<CudaArray*> tabulatedFunctions;
......
......@@ -4430,6 +4430,10 @@ CudaCalcCustomManyParticleForceKernel::~CudaCalcCustomManyParticleForceKernel()
delete params;
if (globals != NULL)
delete globals;
if (orderIndex != NULL)
delete orderIndex;
if (particleOrder != NULL)
delete particleOrder;
if (particleTypes != NULL)
delete particleTypes;
for (int i = 0; i < (int) tabulatedFunctions.size(); i++)
......@@ -4513,6 +4517,27 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
}
}
// Build data structures for type filters.
vector<int> particleTypesVec;
vector<int> orderIndexVec;
vector<std::vector<int> > particleOrderVec;
int numTypes;
CustomManyParticleForceImpl::buildFilterArrays(force, numTypes, particleTypesVec, orderIndexVec, particleOrderVec);
bool hasTypeFilters = (particleOrderVec.size() > 1);
if (hasTypeFilters) {
particleTypes = CudaArray::create<int>(cu, particleTypesVec.size(), "customManyParticleTypes");
orderIndex = CudaArray::create<int>(cu, orderIndexVec.size(), "customManyParticleOrderIndex");
particleOrder = CudaArray::create<int>(cu, particleOrderVec.size()*particlesPerSet, "customManyParticleOrder");
particleTypes->upload(particleTypesVec);
orderIndex->upload(orderIndexVec);
vector<int> flattenedOrder(particleOrder->getSize());
for (int i = 0; i < (int) particleOrderVec.size(); i++)
for (int j = 0; j < particlesPerSet; j++)
flattenedOrder[i*particlesPerSet+j] = particleOrderVec[i][j];
particleOrder->upload(flattenedOrder);
}
// Now to generate the kernel. First, it needs to calculate all distances, angles,
// and dihedrals the expression depends on.
......@@ -4677,7 +4702,19 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
// Create other replacements that depend on the number of particles per set.
stringstream numCombinations, atomsForCombination, isValidCombination, permute, loadData, verifyCutoff;
if (hasTypeFilters) {
permute<<"int particleSet[] = {";
for (int i = 0; i < particlesPerSet; i++) {
permute<<"p"<<(i+1);
if (i < particlesPerSet-1)
permute<<", ";
}
permute<<"};\n";
}
for (int i = 0; i < particlesPerSet; i++) {
if (hasTypeFilters)
permute<<"int atom"<<(i+1)<<" = particleSet[particleOrder["<<numTypes<<"*order+"<<i<<"]];\n";
else
permute<<"int atom"<<(i+1)<<" = p"<<(i+1)<<";\n";
loadData<<"real4 pos"<<(i+1)<<" = posq[atom"<<(i+1)<<"];\n";
for (int j = 0; j < (int) params->getBuffers().size(); j++)
......@@ -4704,6 +4741,9 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
for (int j = i+1; j < particlesPerSet; j++)
verifyCutoff<<"includeInteraction &= (delta(pos"<<(i+1)<<", pos"<<(j+1)<<", periodicBoxSize, invPeriodicBoxSize).w < CUTOFF_SQUARED);\n";
}
string computeTypeIndex = "particleTypes[p"+cu.intToString(particlesPerSet)+"]";
for (int i = particlesPerSet-2; i >= 0; i--)
computeTypeIndex = "particleTypes[p"+cu.intToString(i+1)+"]+"+cu.intToString(numTypes)+"*("+computeTypeIndex+")";
// Create replacements for extra arguments.
......@@ -4725,12 +4765,15 @@ void CudaCalcCustomManyParticleForceKernel::initialize(const System& system, con
replacements["VERIFY_CUTOFF"] = verifyCutoff.str();
replacements["PERMUTE_ATOMS"] = permute.str();
replacements["LOAD_PARTICLE_DATA"] = loadData.str();
replacements["COMPUTE_TYPE_INDEX"] = computeTypeIndex;
replacements["PARAMETER_ARGUMENTS"] = extraArgs.str()+tableArgs.str();
map<string, string> defines;
if (nonbondedMethod != NoCutoff)
defines["USE_CUTOFF"] = "1";
if (nonbondedMethod == CutoffPeriodic)
defines["USE_PERIODIC"] = "1";
if (hasTypeFilters)
defines["USE_FILTERS"] = "1";
defines["NUM_ATOMS"] = cu.intToString(cu.getNumAtoms());
defines["PADDED_NUM_ATOMS"] = cu.intToString(cu.getPaddedNumAtoms());
defines["M_PI"] = cu.doubleToString(M_PI);
......@@ -4749,6 +4792,11 @@ double CudaCalcCustomManyParticleForceKernel::execute(ContextImpl& context, bool
forceArgs.push_back(&cu.getPosq().getDevicePointer());
forceArgs.push_back(cu.getPeriodicBoxSizePointer());
forceArgs.push_back(cu.getInvPeriodicBoxSizePointer());
if (particleTypes != NULL) {
forceArgs.push_back(&particleTypes->getDevicePointer());
forceArgs.push_back(&orderIndex->getDevicePointer());
forceArgs.push_back(&particleOrder->getDevicePointer());
}
if (globals != NULL)
forceArgs.push_back(&globals->getDevicePointer());
for (int i = 0; i < (int) params->getBuffers().size(); i++) {
......
......@@ -64,6 +64,9 @@ inline __device__ real4 computeCross(real4 vec1, real4 vec2) {
extern "C" __global__ void computeInteraction(
unsigned long long* __restrict__ forceBuffers, real* __restrict__ energyBuffer, const real4* __restrict__ posq,
real4 periodicBoxSize, real4 invPeriodicBoxSize
#ifdef USE_FILTERS
, int* __restrict__ particleTypes, int* __restrict__ orderIndex, int* __restrict__ particleOrder
#endif
PARAMETER_ARGUMENTS) {
real energy = 0.0f;
......@@ -79,6 +82,11 @@ extern "C" __global__ void computeInteraction(
if (includeInteraction) {
VERIFY_CUTOFF;
}
#endif
#ifdef USE_FILTERS
int order = orderIndex[COMPUTE_TYPE_INDEX];
if (order == -1)
includeInteraction = false;
#endif
if (includeInteraction) {
PERMUTE_ATOMS;
......
......@@ -469,10 +469,10 @@ int main(int argc, char* argv[]) {
testCutoff();
testPeriodic();
// testExclusions();
// testAllTerms();
testAllTerms();
testParameters();
testTabulatedFunctions();
// testTypeFilters();
testTypeFilters();
}
catch(const exception& e) {
cout << "exception: " << e.what() << endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment