Unverified Commit 5e913400 authored by Stephen Farr's avatar Stephen Farr Committed by GitHub
Browse files

another attempted fix for MonteCarloBarostat (#4119)

* save atomIndex order in MCBarostat kernel

* Update device indices

* Only setAtomIndex if there was a reorder
parent 7f41337f
...@@ -1576,7 +1576,7 @@ public: ...@@ -1576,7 +1576,7 @@ public:
void restoreCoordinates(ContextImpl& context); void restoreCoordinates(ContextImpl& context);
private: private:
ComputeContext& cc; ComputeContext& cc;
bool hasInitializedKernels, rigidMolecules; bool hasInitializedKernels, rigidMolecules, atomsWereReordered;
int numMolecules; int numMolecules;
ComputeArray savedPositions, savedFloatForces, savedLongForces; ComputeArray savedPositions, savedFloatForces, savedLongForces;
ComputeArray moleculeAtoms; ComputeArray moleculeAtoms;
......
...@@ -387,6 +387,10 @@ public: ...@@ -387,6 +387,10 @@ public:
const std::vector<int>& getAtomIndex() const { const std::vector<int>& getAtomIndex() const {
return atomIndex; return atomIndex;
} }
/**
* Set the vector which contains the index of each atom.
*/
void setAtomIndex(std::vector<int>& index);
/** /**
* Get the array which contains the index of each atom. * Get the array which contains the index of each atom.
*/ */
......
...@@ -7720,10 +7720,15 @@ void CommonApplyMonteCarloBarostatKernel::saveCoordinates(ContextImpl& context) ...@@ -7720,10 +7720,15 @@ void CommonApplyMonteCarloBarostatKernel::saveCoordinates(ContextImpl& context)
if (savedFloatForces.isInitialized()) if (savedFloatForces.isInitialized())
cc.getFloatForceBuffer().copyTo(savedFloatForces); cc.getFloatForceBuffer().copyTo(savedFloatForces);
lastPosCellOffsets = cc.getPosCellOffsets(); lastPosCellOffsets = cc.getPosCellOffsets();
lastAtomOrder = cc.getAtomIndex();
} }
void CommonApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, double scaleX, double scaleY, double scaleZ) { void CommonApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, double scaleX, double scaleY, double scaleZ) {
ContextSelector selector(cc); ContextSelector selector(cc);
// check if atoms were reordered from energy evaluation before scaling
atomsWereReordered = cc.getAtomsWereReordered();
if (!hasInitializedKernels) { if (!hasInitializedKernels) {
hasInitializedKernels = true; hasInitializedKernels = true;
...@@ -7769,7 +7774,6 @@ void CommonApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context, ...@@ -7769,7 +7774,6 @@ void CommonApplyMonteCarloBarostatKernel::scaleCoordinates(ContextImpl& context,
kernel->setArg(2, (float) scaleZ); kernel->setArg(2, (float) scaleZ);
setPeriodicBoxArgs(cc, kernel, 4); setPeriodicBoxArgs(cc, kernel, 4);
kernel->execute(cc.getNumAtoms()); kernel->execute(cc.getNumAtoms());
lastAtomOrder = cc.getAtomIndex();
} }
void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) { void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& context) {
...@@ -7779,4 +7783,8 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex ...@@ -7779,4 +7783,8 @@ void CommonApplyMonteCarloBarostatKernel::restoreCoordinates(ContextImpl& contex
cc.setPosCellOffsets(lastPosCellOffsets); cc.setPosCellOffsets(lastPosCellOffsets);
if (savedFloatForces.isInitialized()) if (savedFloatForces.isInitialized())
savedFloatForces.copyTo(cc.getFloatForceBuffer()); savedFloatForces.copyTo(cc.getFloatForceBuffer());
// check if atoms were reordered from energy evaluation before or after scaling
if (atomsWereReordered || cc.getAtomsWereReordered())
cc.setAtomIndex(lastAtomOrder);
} }
\ No newline at end of file
...@@ -57,6 +57,13 @@ void ComputeContext::addForce(ComputeForceInfo* force) { ...@@ -57,6 +57,13 @@ void ComputeContext::addForce(ComputeForceInfo* force) {
forces.push_back(force); forces.push_back(force);
} }
void ComputeContext::setAtomIndex(std::vector<int>& index){
atomIndex = index;
getAtomIndexArray().upload(atomIndex);
for (auto listener : reorderListeners)
listener->execute();
}
string ComputeContext::replaceStrings(const string& input, const std::map<std::string, std::string>& replacements) const { string ComputeContext::replaceStrings(const string& input, const std::map<std::string, std::string>& replacements) const {
static set<char> symbolChars; static set<char> symbolChars;
if (symbolChars.size() == 0) { if (symbolChars.size() == 0) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment