Commit 19841146 authored by peastman
Browse files

Bug fix

parent 94aa8c3f
......@@ -1400,27 +1400,32 @@ private:
// Pre-computation hook, derived from CudaContext::ForcePreComputation, that
// makes the stored CUDA stream wait on an event recorded at the start of a
// force evaluation. The wait is only issued when the stored force group is
// part of the evaluation.
class CudaCalcNonbondedForceKernel::SyncStreamPreComputation : public CudaContext::ForcePreComputation {
public:
// NOTE(review): the next two lines look like the pre-change and post-change
// versions of the same constructor from a diff whose +/- markers were lost;
// only the three-argument form initializes forceGroup. Confirm against the
// repository before relying on either.
SyncStreamPreComputation(CUstream stream, CUevent event) : stream(stream), event(event) {
SyncStreamPreComputation(CUstream stream, CUevent event, int forceGroup) : stream(stream), event(event), forceGroup(forceGroup) {
}
// includeForces and includeEnergy are not read here. groups is treated as a
// bit mask in which bit number forceGroup selects this hook.
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
// Skip the synchronization entirely when this force group is not requested.
if ((groups&(1<<forceGroup)) != 0) {
// Record the event with stream handle 0, then make the stored stream wait
// on that event before running anything enqueued on it afterwards.
cuEventRecord(event, 0);
cuStreamWaitEvent(stream, event, 0);
}
}
private:
// Stream that is made to wait on the event.
CUstream stream;
// Event recorded and waited on by computeForceAndEnergy().
CUevent event;
// Bit index tested against the groups mask.
int forceGroup;
};
// Post-computation hook, derived from CudaContext::ForcePostComputation, that
// makes stream handle 0 wait on the stored event. The wait is only issued when
// the stored force group is part of the evaluation. It contributes no energy.
class CudaCalcNonbondedForceKernel::SyncStreamPostComputation : public CudaContext::ForcePostComputation {
public:
// NOTE(review): the next two lines look like the pre-change and post-change
// versions of the same constructor from a diff whose +/- markers were lost;
// only the two-argument form initializes forceGroup. Confirm against the
// repository before relying on either.
SyncStreamPostComputation(CUevent event) : event(event) {
SyncStreamPostComputation(CUevent event, int forceGroup) : event(event), forceGroup(forceGroup) {
}
// includeForces and includeEnergy are not read here. groups is treated as a
// bit mask in which bit number forceGroup selects this hook.
// Always returns 0.0.
double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
// Only wait when this force group is requested.
// NOTE(review): presumably the event is recorded elsewhere on the other
// stream when that work is launched - not visible in this excerpt.
if ((groups&(1<<forceGroup)) != 0)
cuStreamWaitEvent(0, event, 0);
return 0.0;
}
private:
// Event that stream handle 0 is made to wait on.
CUevent event;
// Bit index tested against the groups mask.
int forceGroup;
};
CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
......@@ -1669,8 +1674,11 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
cufftSetStream(fftForward, pmeStream);
cufftSetStream(fftBackward, pmeStream);
CHECK_RESULT(cuEventCreate(&pmeSyncEvent, CU_EVENT_DISABLE_TIMING), "Error creating event for NonbondedForce");
cu.addPreComputation(new SyncStreamPreComputation(pmeStream, pmeSyncEvent));
cu.addPostComputation(new SyncStreamPostComputation(pmeSyncEvent));
int recipForceGroup = force.getReciprocalSpaceForceGroup();
if (recipForceGroup < 0)
recipForceGroup = force.getForceGroup();
cu.addPreComputation(new SyncStreamPreComputation(pmeStream, pmeSyncEvent, recipForceGroup));
cu.addPostComputation(new SyncStreamPostComputation(pmeSyncEvent, recipForceGroup));
hasInitializedFFT = true;
// Initialize the b-spline moduli.
......
......@@ -1386,31 +1386,37 @@ private:
// Pre-computation hook, derived from OpenCLContext::ForcePreComputation, that
// makes the stored command queue wait on a marker enqueued on the context's
// own queue. The wait is only issued when the stored force group is part of
// the evaluation.
class OpenCLCalcNonbondedForceKernel::SyncQueuePreComputation : public OpenCLContext::ForcePreComputation {
public:
// NOTE(review): the next two lines look like the pre-change and post-change
// versions of the same constructor from a diff whose +/- markers were lost;
// only the three-argument form initializes forceGroup. Confirm against the
// repository before relying on either.
SyncQueuePreComputation(OpenCLContext& cl, cl::CommandQueue queue) : cl(cl), queue(queue), events(1) {
SyncQueuePreComputation(OpenCLContext& cl, cl::CommandQueue queue, int forceGroup) : cl(cl), queue(queue), events(1), forceGroup(forceGroup) {
}
// includeForces and includeEnergy are not read here. groups is treated as a
// bit mask in which bit number forceGroup selects this hook.
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
// Skip the synchronization entirely when this force group is not requested.
if ((groups&(1<<forceGroup)) != 0) {
// Enqueue a marker on the context's queue, capturing its event in events[0],
// then make the stored queue wait for that event.
cl.getQueue().enqueueMarker(&events[0]);
queue.enqueueWaitForEvents(events);
}
}
private:
// Context whose queue (cl.getQueue()) receives the marker.
OpenCLContext& cl;
// Queue that is made to wait on the marker event.
cl::CommandQueue queue;
// One-element list reused on every call to hold the marker event.
vector<cl::Event> events;
// Bit index tested against the groups mask.
int forceGroup;
};
// Post-computation hook, derived from OpenCLContext::ForcePostComputation,
// that makes the context's own queue wait on an externally owned event. The
// wait is only issued when the stored force group is part of the evaluation.
// It contributes no energy.
class OpenCLCalcNonbondedForceKernel::SyncQueuePostComputation : public OpenCLContext::ForcePostComputation {
public:
// NOTE(review): the next two lines look like the pre-change and post-change
// versions of the same constructor from a diff whose +/- markers were lost;
// only the three-argument form initializes forceGroup. Confirm against the
// repository before relying on either.
SyncQueuePostComputation(OpenCLContext& cl, cl::Event& event) : cl(cl), event(event), events(1) {
SyncQueuePostComputation(OpenCLContext& cl, cl::Event& event, int forceGroup) : cl(cl), event(event), events(1), forceGroup(forceGroup) {
}
// includeForces and includeEnergy are not read here. groups is treated as a
// bit mask in which bit number forceGroup selects this hook.
// Always returns 0.0.
double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
// Only wait when this force group is requested.
if ((groups&(1<<forceGroup)) != 0) {
// Copy the current value of the referenced event into the one-element list
// at call time, then make the context's queue wait for it.
events[0] = event;
cl.getQueue().enqueueWaitForEvents(events);
}
return 0.0;
}
private:
// Context whose queue (cl.getQueue()) is made to wait.
OpenCLContext& cl;
// Held by reference, so the value read in computeForceAndEnergy() is whatever
// the owner has stored in it by then.
cl::Event& event;
// One-element list reused on every call to pass the event to the wait.
vector<cl::Event> events;
// Bit index tested against the groups mask.
int forceGroup;
};
OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() {
......@@ -1604,8 +1610,11 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms());
fft = new OpenCLFFT3D(cl, gridSizeX, gridSizeY, gridSizeZ);
pmeQueue = cl::CommandQueue(cl.getContext(), cl.getDevice());
cl.addPreComputation(new SyncQueuePreComputation(cl, pmeQueue));
cl.addPostComputation(new SyncQueuePostComputation(cl, pmeSyncEvent));
int recipForceGroup = force.getReciprocalSpaceForceGroup();
if (recipForceGroup < 0)
recipForceGroup = force.getForceGroup();
cl.addPreComputation(new SyncQueuePreComputation(cl, pmeQueue, recipForceGroup));
cl.addPostComputation(new SyncQueuePostComputation(cl, pmeSyncEvent, recipForceGroup));
// Initialize the b-spline moduli.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment