Commit 19841146 authored by peastman
Browse files

Bug fix

parent 94aa8c3f
...@@ -1400,27 +1400,32 @@ private: ...@@ -1400,27 +1400,32 @@ private:
class CudaCalcNonbondedForceKernel::SyncStreamPreComputation : public CudaContext::ForcePreComputation { class CudaCalcNonbondedForceKernel::SyncStreamPreComputation : public CudaContext::ForcePreComputation {
public: public:
SyncStreamPreComputation(CUstream stream, CUevent event) : stream(stream), event(event) { SyncStreamPreComputation(CUstream stream, CUevent event, int forceGroup) : stream(stream), event(event), forceGroup(forceGroup) {
} }
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) { void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
cuEventRecord(event, 0); if ((groups&(1<<forceGroup)) != 0) {
cuStreamWaitEvent(stream, event, 0); cuEventRecord(event, 0);
cuStreamWaitEvent(stream, event, 0);
}
} }
private: private:
CUstream stream; CUstream stream;
CUevent event; CUevent event;
int forceGroup;
}; };
class CudaCalcNonbondedForceKernel::SyncStreamPostComputation : public CudaContext::ForcePostComputation { class CudaCalcNonbondedForceKernel::SyncStreamPostComputation : public CudaContext::ForcePostComputation {
public: public:
SyncStreamPostComputation(CUevent event) : event(event) { SyncStreamPostComputation(CUevent event, int forceGroup) : event(event), forceGroup(forceGroup) {
} }
double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) { double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
cuStreamWaitEvent(0, event, 0); if ((groups&(1<<forceGroup)) != 0)
cuStreamWaitEvent(0, event, 0);
return 0.0; return 0.0;
} }
private: private:
CUevent event; CUevent event;
int forceGroup;
}; };
CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() { CudaCalcNonbondedForceKernel::~CudaCalcNonbondedForceKernel() {
...@@ -1669,8 +1674,11 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon ...@@ -1669,8 +1674,11 @@ void CudaCalcNonbondedForceKernel::initialize(const System& system, const Nonbon
cufftSetStream(fftForward, pmeStream); cufftSetStream(fftForward, pmeStream);
cufftSetStream(fftBackward, pmeStream); cufftSetStream(fftBackward, pmeStream);
CHECK_RESULT(cuEventCreate(&pmeSyncEvent, CU_EVENT_DISABLE_TIMING), "Error creating event for NonbondedForce"); CHECK_RESULT(cuEventCreate(&pmeSyncEvent, CU_EVENT_DISABLE_TIMING), "Error creating event for NonbondedForce");
cu.addPreComputation(new SyncStreamPreComputation(pmeStream, pmeSyncEvent)); int recipForceGroup = force.getReciprocalSpaceForceGroup();
cu.addPostComputation(new SyncStreamPostComputation(pmeSyncEvent)); if (recipForceGroup < 0)
recipForceGroup = force.getForceGroup();
cu.addPreComputation(new SyncStreamPreComputation(pmeStream, pmeSyncEvent, recipForceGroup));
cu.addPostComputation(new SyncStreamPostComputation(pmeSyncEvent, recipForceGroup));
hasInitializedFFT = true; hasInitializedFFT = true;
// Initialize the b-spline moduli. // Initialize the b-spline moduli.
......
...@@ -1386,31 +1386,37 @@ private: ...@@ -1386,31 +1386,37 @@ private:
class OpenCLCalcNonbondedForceKernel::SyncQueuePreComputation : public OpenCLContext::ForcePreComputation { class OpenCLCalcNonbondedForceKernel::SyncQueuePreComputation : public OpenCLContext::ForcePreComputation {
public: public:
SyncQueuePreComputation(OpenCLContext& cl, cl::CommandQueue queue) : cl(cl), queue(queue), events(1) { SyncQueuePreComputation(OpenCLContext& cl, cl::CommandQueue queue, int forceGroup) : cl(cl), queue(queue), events(1), forceGroup(forceGroup) {
} }
void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) { void computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
cl.getQueue().enqueueMarker(&events[0]); if ((groups&(1<<forceGroup)) != 0) {
queue.enqueueWaitForEvents(events); cl.getQueue().enqueueMarker(&events[0]);
queue.enqueueWaitForEvents(events);
}
} }
private: private:
OpenCLContext& cl; OpenCLContext& cl;
cl::CommandQueue queue; cl::CommandQueue queue;
vector<cl::Event> events; vector<cl::Event> events;
int forceGroup;
}; };
class OpenCLCalcNonbondedForceKernel::SyncQueuePostComputation : public OpenCLContext::ForcePostComputation { class OpenCLCalcNonbondedForceKernel::SyncQueuePostComputation : public OpenCLContext::ForcePostComputation {
public: public:
SyncQueuePostComputation(OpenCLContext& cl, cl::Event& event) : cl(cl), event(event), events(1) { SyncQueuePostComputation(OpenCLContext& cl, cl::Event& event, int forceGroup) : cl(cl), event(event), events(1), forceGroup(forceGroup) {
} }
double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) { double computeForceAndEnergy(bool includeForces, bool includeEnergy, int groups) {
events[0] = event; if ((groups&(1<<forceGroup)) != 0) {
cl.getQueue().enqueueWaitForEvents(events); events[0] = event;
cl.getQueue().enqueueWaitForEvents(events);
}
return 0.0; return 0.0;
} }
private: private:
OpenCLContext& cl; OpenCLContext& cl;
cl::Event& event; cl::Event& event;
vector<cl::Event> events; vector<cl::Event> events;
int forceGroup;
}; };
OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() { OpenCLCalcNonbondedForceKernel::~OpenCLCalcNonbondedForceKernel() {
...@@ -1604,8 +1610,11 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb ...@@ -1604,8 +1610,11 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms()); sort = new OpenCLSort(cl, new SortTrait(), cl.getNumAtoms());
fft = new OpenCLFFT3D(cl, gridSizeX, gridSizeY, gridSizeZ); fft = new OpenCLFFT3D(cl, gridSizeX, gridSizeY, gridSizeZ);
pmeQueue = cl::CommandQueue(cl.getContext(), cl.getDevice()); pmeQueue = cl::CommandQueue(cl.getContext(), cl.getDevice());
cl.addPreComputation(new SyncQueuePreComputation(cl, pmeQueue)); int recipForceGroup = force.getReciprocalSpaceForceGroup();
cl.addPostComputation(new SyncQueuePostComputation(cl, pmeSyncEvent)); if (recipForceGroup < 0)
recipForceGroup = force.getForceGroup();
cl.addPreComputation(new SyncQueuePreComputation(cl, pmeQueue, recipForceGroup));
cl.addPostComputation(new SyncQueuePostComputation(cl, pmeSyncEvent, recipForceGroup));
// Initialize the b-spline moduli. // Initialize the b-spline moduli.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment