Unverified Commit d47ea1de authored by Peter Eastman's avatar Peter Eastman Committed by GitHub
Browse files

Unified storage of global parameters (#5002)

* Unified storage of global parameters

* Fixes to CUDA and HIP

* Store global parameters as real instead of float
parent 6c119bc2
...@@ -158,8 +158,8 @@ public: ...@@ -158,8 +158,8 @@ public:
/** /**
* Get the ContextImpl is ComputeContext is associated with. * Get the ContextImpl is ComputeContext is associated with.
*/ */
ContextImpl& getContextImpl() { ContextImpl* getContextImpl() {
return *platformData.context; return platformData.context;
} }
/** /**
* Get a workspace used for accumulating energy when a simulation is parallelized across * Get a workspace used for accumulating energy when a simulation is parallelized across
......
...@@ -69,7 +69,6 @@ class HipContext; ...@@ -69,7 +69,6 @@ class HipContext;
class OPENMM_EXPORT_COMMON HipNonbondedUtilities : public NonbondedUtilities { class OPENMM_EXPORT_COMMON HipNonbondedUtilities : public NonbondedUtilities {
public: public:
class ParameterInfo;
HipNonbondedUtilities(HipContext& context); HipNonbondedUtilities(HipContext& context);
~HipNonbondedUtilities(); ~HipNonbondedUtilities();
/** /**
...@@ -93,22 +92,10 @@ public: ...@@ -93,22 +92,10 @@ public:
* Add a per-atom parameter that the default interaction kernel may depend on. * Add a per-atom parameter that the default interaction kernel may depend on.
*/ */
void addParameter(ComputeParameterInfo parameter); void addParameter(ComputeParameterInfo parameter);
/**
* Add a per-atom parameter that the default interaction kernel may depend on.
*
* @deprecated Use the version that takes a ComputeParameterInfo instead.
*/
void addParameter(const ParameterInfo& parameter);
/** /**
* Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel. * Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel.
*/ */
void addArgument(ComputeParameterInfo parameter); void addArgument(ComputeParameterInfo parameter);
/**
* Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel.
*
* @deprecated Use the version that takes a ComputeParameterInfo instead.
*/
void addArgument(const ParameterInfo& parameter);
/** /**
* Register that the interaction kernel will be computing the derivative of the potential energy * Register that the interaction kernel will be computing the derivative of the potential energy
* with respect to a parameter. * with respect to a parameter.
...@@ -296,7 +283,7 @@ public: ...@@ -296,7 +283,7 @@ public:
* @param includeForces whether this kernel should compute forces * @param includeForces whether this kernel should compute forces
* @param includeEnergy whether this kernel should compute potential energy * @param includeEnergy whether this kernel should compute potential energy
*/ */
hipFunction_t createInteractionKernel(const std::string& source, std::vector<ParameterInfo>& params, std::vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy); hipFunction_t createInteractionKernel(const std::string& source, std::vector<ComputeParameterInfo>& params, std::vector<ComputeParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy);
/** /**
* Create the set of kernels that will be needed for a particular combination of force groups. * Create the set of kernels that will be needed for a particular combination of force groups.
* *
...@@ -311,6 +298,7 @@ public: ...@@ -311,6 +298,7 @@ public:
private: private:
class KernelSet; class KernelSet;
class BlockSortTrait; class BlockSortTrait;
void initParamArgs();
HipContext& context; HipContext& context;
std::map<int, KernelSet> groupKernels; std::map<int, KernelSet> groupKernels;
HipArray exclusionTiles; HipArray exclusionTiles;
...@@ -337,15 +325,15 @@ private: ...@@ -337,15 +325,15 @@ private:
unsigned int* pinnedCountBuffer; unsigned int* pinnedCountBuffer;
std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs, copyInteractionCountsArgs; std::vector<void*> forceArgs, findBlockBoundsArgs, computeSortKeysArgs, sortBoxDataArgs, findInteractingBlocksArgs, copyInteractionCountsArgs;
std::vector<std::vector<int> > atomExclusions; std::vector<std::vector<int> > atomExclusions;
std::vector<ParameterInfo> parameters; std::vector<ComputeParameterInfo> parameters;
std::vector<ParameterInfo> arguments; std::vector<ComputeParameterInfo> arguments;
std::vector<std::string> energyParameterDerivatives; std::vector<std::string> energyParameterDerivatives;
std::map<int, double> groupCutoff; std::map<int, double> groupCutoff;
std::map<int, std::string> groupKernelSource; std::map<int, std::string> groupKernelSource;
double maxCutoff; double maxCutoff;
bool useCutoff, usePeriodic, anyExclusions, usePadding, useNeighborList, forceRebuildNeighborList, canUsePairList, useLargeBlocks; bool useCutoff, usePeriodic, anyExclusions, usePadding, useNeighborList, forceRebuildNeighborList, canUsePairList, useLargeBlocks, hasInitializedParams;
int startTileIndex, startBlockIndex, numBlocks, numTilesInBatch, maxExclusions; int startTileIndex, startBlockIndex, numBlocks, numTilesInBatch, maxExclusions;
int numForceThreadBlocks, forceThreadBlockSize, findInteractingBlocksThreadBlockSize, numAtoms, groupFlags; int numForceThreadBlocks, forceThreadBlockSize, findInteractingBlocksThreadBlockSize, numAtoms, groupFlags, paramStartIndex;
unsigned int maxTiles, maxSinglePairs, tilesAfterReorder; unsigned int maxTiles, maxSinglePairs, tilesAfterReorder;
long long numTiles; long long numTiles;
std::string kernelSource; std::string kernelSource;
...@@ -367,62 +355,6 @@ public: ...@@ -367,62 +355,6 @@ public:
hipFunction_t copyInteractionCountsKernel; hipFunction_t copyInteractionCountsKernel;
}; };
/**
* This class stores information about a per-atom parameter that may be used in a nonbonded kernel.
*/
class HipNonbondedUtilities::ParameterInfo {
public:
/**
* Create a ParameterInfo object.
*
* @param name the name of the parameter
* @param type the data type of the parameter's components
* @param numComponents the number of components in the parameter
* @param size the size of the parameter in bytes
* @param memory the memory containing the parameter values
* @param constant whether the memory should be marked as constant
*/
ParameterInfo(const std::string& name, const std::string& componentType, int numComponents, int size, hipDeviceptr_t memory, bool constant=true) :
name(name), componentType(componentType), numComponents(numComponents), size(size), memory(memory), constant(constant) {
if (numComponents == 1)
type = componentType;
else {
std::stringstream s;
s << componentType << numComponents;
type = s.str();
}
}
const std::string& getName() const {
return name;
}
const std::string& getComponentType() const {
return componentType;
}
const std::string& getType() const {
return type;
}
int getNumComponents() const {
return numComponents;
}
int getSize() const {
return size;
}
hipDeviceptr_t& getMemory() {
return memory;
}
bool isConstant() const {
return constant;
}
private:
std::string name;
std::string componentType;
std::string type;
int size, numComponents;
hipDeviceptr_t memory;
bool constant;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_HIPNONBONDEDUTILITIES_H_*/ #endif /*OPENMM_HIPNONBONDEDUTILITIES_H_*/
#ifndef OPENMM_HIPPARAMETERSET_H_
#define OPENMM_HIPPARAMETERSET_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Portions copyright (c) 2020 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "HipContext.h"
#include "HipNonbondedUtilities.h"
#include "openmm/common/ComputeParameterSet.h"
namespace OpenMM {
class HipNonbondedUtilities;
/**
* This class exists for backward compatibility. For most purposes you can use
* ComputeParameterSet directly instead.
*/
class OPENMM_EXPORT_COMMON HipParameterSet : public ComputeParameterSet {
public:
/**
* Create an HipParameterSet.
*
* @param context the context for which to create the parameter set
* @param numParameters the number of parameters for each object
* @param numObjects the number of objects to store parameter values for
* @param name the name of the parameter set
* @param bufferPerParameter if true, a separate buffer is created for each parameter. If false,
* multiple parameters may be combined into a single buffer.
* @param useDoublePrecision whether values should be stored as single or double precision
*/
HipParameterSet(HipContext& context, int numParameters, int numObjects, const std::string& name, bool bufferPerParameter=false, bool useDoublePrecision=false);
/**
* Get a set of HipNonbondedUtilities::ParameterInfo objects which describe the Buffers
* containing the data.
*/
std::vector<HipNonbondedUtilities::ParameterInfo>& getBuffers() {
return buffers;
}
private:
std::vector<HipNonbondedUtilities::ParameterInfo> buffers;
};
} // namespace OpenMM
#endif /*OPENMM_HIPPARAMETERSET_H_*/
...@@ -63,6 +63,7 @@ void HipCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool i ...@@ -63,6 +63,7 @@ void HipCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool i
cu.setForcesValid(true); cu.setForcesValid(true);
ContextSelector selector(cu); ContextSelector selector(cu);
cu.clearAutoclearBuffers(); cu.clearAutoclearBuffers();
cu.updateGlobalParamValues();
for (auto computation : cu.getPreComputations()) for (auto computation : cu.getPreComputations())
computation->computeForceAndEnergy(includeForces, includeEnergy, groups); computation->computeForceAndEnergy(includeForces, includeEnergy, groups);
HipNonbondedUtilities& nb = cu.getNonbondedUtilities(); HipNonbondedUtilities& nb = cu.getNonbondedUtilities();
......
...@@ -115,20 +115,10 @@ void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, b ...@@ -115,20 +115,10 @@ void HipNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic, b
} }
void HipNonbondedUtilities::addParameter(ComputeParameterInfo parameter) { void HipNonbondedUtilities::addParameter(ComputeParameterInfo parameter) {
parameters.push_back(ParameterInfo(parameter.getName(), parameter.getComponentType(), parameter.getNumComponents(),
parameter.getSize(), context.unwrap(parameter.getArray()).getDevicePointer(), parameter.isConstant()));
}
void HipNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
parameters.push_back(parameter); parameters.push_back(parameter);
} }
void HipNonbondedUtilities::addArgument(ComputeParameterInfo parameter) { void HipNonbondedUtilities::addArgument(ComputeParameterInfo parameter) {
arguments.push_back(ParameterInfo(parameter.getName(), parameter.getComponentType(), parameter.getNumComponents(),
parameter.getSize(), context.unwrap(parameter.getArray()).getDevicePointer(), parameter.isConstant()));
}
void HipNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
arguments.push_back(parameter); arguments.push_back(parameter);
} }
...@@ -333,10 +323,10 @@ void HipNonbondedUtilities::initialize(const System& system) { ...@@ -333,10 +323,10 @@ void HipNonbondedUtilities::initialize(const System& system) {
forceArgs.push_back(&maxSinglePairs); forceArgs.push_back(&maxSinglePairs);
forceArgs.push_back(&singlePairs.getDevicePointer()); forceArgs.push_back(&singlePairs.getDevicePointer());
} }
for (int i = 0; i < (int) parameters.size(); i++) hasInitializedParams = false;
forceArgs.push_back(&parameters[i].getMemory()); paramStartIndex = forceArgs.size();
for (ParameterInfo& arg : arguments) for (int i = 0; i < parameters.size()+arguments.size(); i++)
forceArgs.push_back(&arg.getMemory()); forceArgs.push_back(NULL);
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
forceArgs.push_back(&context.getEnergyParamDerivBuffer().getDevicePointer()); forceArgs.push_back(&context.getEnergyParamDerivBuffer().getDevicePointer());
if (useCutoff) { if (useCutoff) {
...@@ -445,6 +435,15 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) { ...@@ -445,6 +435,15 @@ void HipNonbondedUtilities::prepareInteractions(int forceGroups) {
hipEventRecord(downloadCountEvent, context.getCurrentStream()); hipEventRecord(downloadCountEvent, context.getCurrentStream());
} }
void HipNonbondedUtilities::initParamArgs() {
int index = paramStartIndex;
for (ComputeParameterInfo& param : parameters)
forceArgs[index++] = &context.unwrap(param.getArray()).getDevicePointer();
for (ComputeParameterInfo& arg : arguments)
forceArgs[index++] = &context.unwrap(arg.getArray()).getDevicePointer();
hasInitializedParams = true;
}
void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) { void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeForces, bool includeEnergy) {
if ((forceGroups&groupFlags) == 0) if ((forceGroups&groupFlags) == 0)
return; return;
...@@ -453,6 +452,8 @@ void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeFor ...@@ -453,6 +452,8 @@ void HipNonbondedUtilities::computeInteractions(int forceGroups, bool includeFor
hipFunction_t& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel); hipFunction_t& kernel = (includeForces ? (includeEnergy ? kernels.forceEnergyKernel : kernels.forceKernel) : kernels.energyKernel);
if (kernel == NULL) if (kernel == NULL)
kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy); kernel = createInteractionKernel(kernels.source, parameters, arguments, true, true, forceGroups, includeForces, includeEnergy);
if (!hasInitializedParams)
initParamArgs();
context.executeKernelFlat(kernel, &forceArgs[0], numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize); context.executeKernelFlat(kernel, &forceArgs[0], numForceThreadBlocks*forceThreadBlockSize, forceThreadBlockSize);
} }
if (useNeighborList && numTiles > 0) { if (useNeighborList && numTiles > 0) {
...@@ -586,12 +587,12 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) { ...@@ -586,12 +587,12 @@ void HipNonbondedUtilities::createKernelsForGroups(int groups) {
groupKernels[groups] = kernels; groupKernels[groups] = kernels;
} }
hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& source, vector<ParameterInfo>& params, vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) { hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& source, vector<ComputeParameterInfo>& params, vector<ComputeParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) {
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_INTERACTION"] = source; replacements["COMPUTE_INTERACTION"] = source;
const string suffixes[] = {"x", "y", "z", "w"}; const string suffixes[] = {"x", "y", "z", "w"};
stringstream args; stringstream args;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
args << ", "; args << ", ";
if (param.isConstant()) if (param.isConstant())
args << "const "; args << "const ";
...@@ -599,7 +600,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -599,7 +600,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc
args << "* __restrict__ global_"; args << "* __restrict__ global_";
args << param.getName(); args << param.getName();
} }
for (const ParameterInfo& arg : arguments) { for (const ComputeParameterInfo& arg : arguments) {
args << ", "; args << ", ";
if (arg.isConstant()) if (arg.isConstant())
args << "const "; args << "const ";
...@@ -612,7 +613,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -612,7 +613,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream load1; stringstream load1;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
load1 << param.getType(); load1 << param.getType();
load1 << " "; load1 << " ";
load1 << param.getName(); load1 << param.getName();
...@@ -629,7 +630,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -629,7 +630,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc
broadcastWarpData << "posq2.y = SHFL(shflPosq.y, j);\n"; broadcastWarpData << "posq2.y = SHFL(shflPosq.y, j);\n";
broadcastWarpData << "posq2.z = SHFL(shflPosq.z, j);\n"; broadcastWarpData << "posq2.z = SHFL(shflPosq.z, j);\n";
broadcastWarpData << "posq2.w = SHFL(shflPosq.w, j);\n"; broadcastWarpData << "posq2.w = SHFL(shflPosq.w, j);\n";
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
broadcastWarpData << param.getType() << " shfl" << param.getName() << ";\n"; broadcastWarpData << param.getType() << " shfl" << param.getName() << ";\n";
for (int j = 0; j < param.getNumComponents(); j++) { for (int j = 0; j < param.getNumComponents(); j++) {
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
...@@ -642,22 +643,22 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -642,22 +643,22 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc
// Part 2. Defines for off-diagonal exclusions, and neighborlist tiles. // Part 2. Defines for off-diagonal exclusions, and neighborlist tiles.
stringstream declareLocal2; stringstream declareLocal2;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
declareLocal2<<param.getType()<<" shfl"<<param.getName()<<";\n"; declareLocal2<<param.getType()<<" shfl"<<param.getName()<<";\n";
replacements["DECLARE_LOCAL_PARAMETERS"] = declareLocal2.str(); replacements["DECLARE_LOCAL_PARAMETERS"] = declareLocal2.str();
stringstream loadLocal2; stringstream loadLocal2;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
loadLocal2<<"shfl"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n"; loadLocal2<<"shfl"<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
stringstream load2j; stringstream load2j;
for (const ParameterInfo& param : params) for (const ComputeParameterInfo& param : params)
load2j<<param.getType()<<" "<<param.getName()<<"2 = shfl"<<param.getName()<<";\n"; load2j<<param.getType()<<" "<<param.getName()<<"2 = shfl"<<param.getName()<<";\n";
replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str(); replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
stringstream clearLocal; stringstream clearLocal;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
clearLocal<<"shfl"; clearLocal<<"shfl";
clearLocal<<param.getName()<<" = "; clearLocal<<param.getName()<<" = ";
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
...@@ -683,7 +684,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -683,7 +684,7 @@ hipFunction_t HipNonbondedUtilities::createInteractionKernel(const string& sourc
stringstream shuffleWarpData; stringstream shuffleWarpData;
shuffleWarpData << "shflPosq = warpRotateLeft<TILE_SIZE>(shflPosq);\n"; shuffleWarpData << "shflPosq = warpRotateLeft<TILE_SIZE>(shflPosq);\n";
shuffleWarpData << "shflForce = warpRotateLeft<TILE_SIZE>(shflForce);\n"; shuffleWarpData << "shflForce = warpRotateLeft<TILE_SIZE>(shflForce);\n";
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
shuffleWarpData<<"shfl"<<param.getName()<<"=warpRotateLeft<TILE_SIZE>(shfl"<<param.getName()<<");\n"; shuffleWarpData<<"shfl"<<param.getName()<<"=warpRotateLeft<TILE_SIZE>(shfl"<<param.getName()<<");\n";
} }
replacements["SHUFFLE_WARP_DATA"] = shuffleWarpData.str(); replacements["SHUFFLE_WARP_DATA"] = shuffleWarpData.str();
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Portions copyright (c) 2020 Advanced Micro Devices, Inc. *
* Authors: Peter Eastman, Nicholas Curtis *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "HipParameterSet.h"
using namespace OpenMM;
using namespace std;
HipParameterSet::HipParameterSet(HipContext& context, int numParameters, int numObjects, const string& name, bool bufferPerParameter, bool useDoublePrecision) :
ComputeParameterSet(context, numParameters, numObjects, name, bufferPerParameter, useDoublePrecision) {
for (auto& info : getParameterInfos())
buffers.push_back(HipNonbondedUtilities::ParameterInfo(info.getName(), info.getComponentType(), info.getNumComponents(), info.getSize(), context.unwrap(info.getArray()).getDevicePointer()));
}
...@@ -199,8 +199,8 @@ public: ...@@ -199,8 +199,8 @@ public:
/** /**
* Get the ContextImpl is ComputeContext is associated with. * Get the ContextImpl is ComputeContext is associated with.
*/ */
ContextImpl& getContextImpl() { ContextImpl* getContextImpl() {
return *platformData.context; return platformData.context;
} }
/** /**
* Get a workspace used for accumulating energy when a simulation is parallelized across * Get a workspace used for accumulating energy when a simulation is parallelized across
......
...@@ -67,7 +67,6 @@ class OpenCLContext; ...@@ -67,7 +67,6 @@ class OpenCLContext;
class OPENMM_EXPORT_COMMON OpenCLNonbondedUtilities : public NonbondedUtilities { class OPENMM_EXPORT_COMMON OpenCLNonbondedUtilities : public NonbondedUtilities {
public: public:
class ParameterInfo;
OpenCLNonbondedUtilities(OpenCLContext& context); OpenCLNonbondedUtilities(OpenCLContext& context);
~OpenCLNonbondedUtilities(); ~OpenCLNonbondedUtilities();
/** /**
...@@ -91,22 +90,10 @@ public: ...@@ -91,22 +90,10 @@ public:
* Add a per-atom parameter that the default interaction kernel may depend on. * Add a per-atom parameter that the default interaction kernel may depend on.
*/ */
void addParameter(ComputeParameterInfo parameter); void addParameter(ComputeParameterInfo parameter);
/**
* Add a per-atom parameter that the default interaction kernel may depend on.
*
* @deprecated Use the version that takes a ComputeParameterInfo instead.
*/
void addParameter(const ParameterInfo& parameter);
/** /**
* Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel. * Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel.
*/ */
void addArgument(ComputeParameterInfo parameter); void addArgument(ComputeParameterInfo parameter);
/**
* Add an array (other than a per-atom parameter) that should be passed as an argument to the default interaction kernel.
*
* @deprecated Use the version that takes a ComputeParameterInfo instead.
*/
void addArgument(const ParameterInfo& parameter);
/** /**
* Register that the interaction kernel will be computing the derivative of the potential energy * Register that the interaction kernel will be computing the derivative of the potential energy
* with respect to a parameter. * with respect to a parameter.
...@@ -294,7 +281,7 @@ public: ...@@ -294,7 +281,7 @@ public:
* @param includeForces whether this kernel should compute forces * @param includeForces whether this kernel should compute forces
* @param includeEnergy whether this kernel should compute potential energy * @param includeEnergy whether this kernel should compute potential energy
*/ */
cl::Kernel createInteractionKernel(const std::string& source, const std::vector<ParameterInfo>& params, const std::vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy); cl::Kernel createInteractionKernel(const std::string& source, std::vector<ComputeParameterInfo>& params, std::vector<ComputeParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy);
/** /**
* Create the set of kernels that will be needed for a particular combination of force groups. * Create the set of kernels that will be needed for a particular combination of force groups.
* *
...@@ -332,8 +319,8 @@ private: ...@@ -332,8 +319,8 @@ private:
cl::Buffer* pinnedCountBuffer; cl::Buffer* pinnedCountBuffer;
unsigned int* pinnedCountMemory; unsigned int* pinnedCountMemory;
std::vector<std::vector<int> > atomExclusions; std::vector<std::vector<int> > atomExclusions;
std::vector<ParameterInfo> parameters; std::vector<ComputeParameterInfo> parameters;
std::vector<ParameterInfo> arguments; std::vector<ComputeParameterInfo> arguments;
std::vector<std::string> energyParameterDerivatives; std::vector<std::string> energyParameterDerivatives;
std::map<int, double> groupCutoff; std::map<int, double> groupCutoff;
std::map<int, std::string> groupKernelSource; std::map<int, std::string> groupKernelSource;
...@@ -362,62 +349,6 @@ public: ...@@ -362,62 +349,6 @@ public:
cl::Kernel findInteractionsWithinBlocksKernel; cl::Kernel findInteractionsWithinBlocksKernel;
}; };
/**
* This class stores information about a per-atom parameter that may be used in a nonbonded kernel.
*/
class OpenCLNonbondedUtilities::ParameterInfo {
public:
/**
* Create a ParameterInfo object.
*
* @param name the name of the parameter
* @param type the data type of the parameter's components
* @param numComponents the number of components in the parameter
* @param size the size of the parameter in bytes
* @param memory the memory containing the parameter values
* @param constant whether the memory should be marked as constant
*/
ParameterInfo(const std::string& name, const std::string& componentType, int numComponents, int size, cl::Memory& memory, bool constant=true) :
name(name), componentType(componentType), numComponents(numComponents), size(size), memory(&memory), constant(constant) {
if (numComponents == 1)
type = componentType;
else {
std::stringstream s;
s << componentType << numComponents;
type = s.str();
}
}
const std::string& getName() const {
return name;
}
const std::string& getComponentType() const {
return componentType;
}
const std::string& getType() const {
return type;
}
int getNumComponents() const {
return numComponents;
}
int getSize() const {
return size;
}
cl::Memory& getMemory() const {
return *memory;
}
bool isConstant() const {
return constant;
}
private:
std::string name;
std::string componentType;
std::string type;
int size, numComponents;
cl::Memory* memory;
bool constant;
};
} // namespace OpenMM } // namespace OpenMM
#endif /*OPENMM_OPENCLNONBONDEDUTILITIES_H_*/ #endif /*OPENMM_OPENCLNONBONDEDUTILITIES_H_*/
#ifndef OPENMM_OPENCLPARAMETERSET_H_
#define OPENMM_OPENCLPARAMETERSET_H_
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2019 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLContext.h"
#include "OpenCLNonbondedUtilities.h"
#include "openmm/common/ComputeParameterSet.h"
namespace OpenMM {
class OpenCLNonbondedUtilities;
/**
* This class exists for backward compatibility. For most purposes you can use
* ComputeParameterSet directly instead.
*/
class OPENMM_EXPORT_COMMON OpenCLParameterSet : public ComputeParameterSet {
public:
/**
* Create an OpenCLParameterSet.
*
* @param context the context for which to create the parameter set
* @param numParameters the number of parameters for each object
* @param numObjects the number of objects to store parameter values for
* @param name the name of the parameter set
* @param bufferPerParameter if true, a separate cl::Buffer is created for each parameter. If false,
* multiple parameters may be combined into a single buffer.
* @param useDoublePrecision whether values should be stored as single or double precision
*/
OpenCLParameterSet(OpenCLContext& context, int numParameters, int numObjects, const std::string& name, bool bufferPerParameter=false, bool useDoublePrecision=false);
/**
* Get a set of OpenCLNonbondedUtilities::ParameterInfo objects which describe the Buffers
* containing the data.
*/
std::vector<OpenCLNonbondedUtilities::ParameterInfo>& getBuffers() {
return buffers;
}
/**
* Get a set of OpenCLNonbondedUtilities::ParameterInfo objects which describe the Buffers
* containing the data.
*/
const std::vector<OpenCLNonbondedUtilities::ParameterInfo>& getBuffers() const {
return buffers;
}
private:
std::vector<OpenCLNonbondedUtilities::ParameterInfo> buffers;
};
} // namespace OpenMM
#endif /*OPENMM_OPENCLPARAMETERSET_H_*/
...@@ -108,7 +108,7 @@ public: ...@@ -108,7 +108,7 @@ public:
class OPENMM_EXPORT_COMMON OpenCLPlatform::PlatformData { class OPENMM_EXPORT_COMMON OpenCLPlatform::PlatformData {
public: public:
PlatformData(const System& system, const std::string& platformPropValue, const std::string& deviceIndexProperty, const std::string& precisionProperty, PlatformData(const System& system, ContextImpl* context, const std::string& platformPropValue, const std::string& deviceIndexProperty, const std::string& precisionProperty,
const std::string& cpuPmeProperty, const std::string& pmeStreamProperty, int numThreads, ContextImpl* originalContext); const std::string& cpuPmeProperty, const std::string& pmeStreamProperty, int numThreads, ContextImpl* originalContext);
~PlatformData(); ~PlatformData();
void initializeContexts(const System& system); void initializeContexts(const System& system);
......
...@@ -55,6 +55,7 @@ void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) { ...@@ -55,6 +55,7 @@ void OpenCLCalcForcesAndEnergyKernel::initialize(const System& system) {
void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) { void OpenCLCalcForcesAndEnergyKernel::beginComputation(ContextImpl& context, bool includeForces, bool includeEnergy, int groups) {
cl.setForcesValid(true); cl.setForcesValid(true);
cl.clearAutoclearBuffers(); cl.clearAutoclearBuffers();
cl.updateGlobalParamValues();
for (auto computation : cl.getPreComputations()) for (auto computation : cl.getPreComputations())
computation->computeForceAndEnergy(includeForces, includeEnergy, groups); computation->computeForceAndEnergy(includeForces, includeEnergy, groups);
OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities(); OpenCLNonbondedUtilities& nb = cl.getNonbondedUtilities();
......
...@@ -123,20 +123,10 @@ void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic ...@@ -123,20 +123,10 @@ void OpenCLNonbondedUtilities::addInteraction(bool usesCutoff, bool usesPeriodic
} }
void OpenCLNonbondedUtilities::addParameter(ComputeParameterInfo parameter) { void OpenCLNonbondedUtilities::addParameter(ComputeParameterInfo parameter) {
parameters.push_back(ParameterInfo(parameter.getName(), parameter.getComponentType(), parameter.getNumComponents(),
parameter.getSize(), context.unwrap(parameter.getArray()).getDeviceBuffer(), parameter.isConstant()));
}
void OpenCLNonbondedUtilities::addParameter(const ParameterInfo& parameter) {
parameters.push_back(parameter); parameters.push_back(parameter);
} }
void OpenCLNonbondedUtilities::addArgument(ComputeParameterInfo parameter) { void OpenCLNonbondedUtilities::addArgument(ComputeParameterInfo parameter) {
arguments.push_back(ParameterInfo(parameter.getName(), parameter.getComponentType(), parameter.getNumComponents(),
parameter.getSize(), context.unwrap(parameter.getArray()).getDeviceBuffer(), parameter.isConstant()));
}
void OpenCLNonbondedUtilities::addArgument(const ParameterInfo& parameter) {
arguments.push_back(parameter); arguments.push_back(parameter);
} }
...@@ -589,13 +579,13 @@ void OpenCLNonbondedUtilities::createKernelsForGroups(int groups) { ...@@ -589,13 +579,13 @@ void OpenCLNonbondedUtilities::createKernelsForGroups(int groups) {
groupKernels[groups] = kernels; groupKernels[groups] = kernels;
} }
cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, const vector<ParameterInfo>& params, const vector<ParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) { cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& source, vector<ComputeParameterInfo>& params, vector<ComputeParameterInfo>& arguments, bool useExclusions, bool isSymmetric, int groups, bool includeForces, bool includeEnergy) {
map<string, string> replacements; map<string, string> replacements;
replacements["COMPUTE_INTERACTION"] = source; replacements["COMPUTE_INTERACTION"] = source;
const string suffixes[] = {"x", "y", "z", "w"}; const string suffixes[] = {"x", "y", "z", "w"};
stringstream localData; stringstream localData;
int localDataSize = 0; int localDataSize = 0;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
localData<<param.getType()<<" "<<param.getName()<<";\n"; localData<<param.getType()<<" "<<param.getName()<<";\n";
else { else {
...@@ -606,7 +596,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -606,7 +596,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
} }
replacements["ATOM_PARAMETER_DATA"] = localData.str(); replacements["ATOM_PARAMETER_DATA"] = localData.str();
stringstream args; stringstream args;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
args << ", __global "; args << ", __global ";
if (param.isConstant()) if (param.isConstant())
args << "const "; args << "const ";
...@@ -617,13 +607,13 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -617,13 +607,13 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
args << "* restrict global_"; args << "* restrict global_";
args << param.getName(); args << param.getName();
} }
for (const ParameterInfo& arg : arguments) { for (ComputeParameterInfo& arg : arguments) {
if (arg.getMemory().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) { if (context.unwrap(arg.getArray()).getDeviceBuffer().getInfo<CL_MEM_TYPE>() == CL_MEM_OBJECT_IMAGE2D) {
args << ", __read_only image2d_t "; args << ", __read_only image2d_t ";
args << arg.getName(); args << arg.getName();
} }
else { else {
if ((arg.getMemory().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0) { if ((context.unwrap(arg.getArray()).getDeviceBuffer().getInfo<CL_MEM_FLAGS>() & CL_MEM_READ_ONLY) == 0) {
args << ", __global "; args << ", __global ";
if (arg.isConstant()) if (arg.isConstant())
args << "const "; args << "const ";
...@@ -639,7 +629,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -639,7 +629,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
args << ", __global mixed* restrict energyParamDerivs"; args << ", __global mixed* restrict energyParamDerivs";
replacements["PARAMETER_ARGUMENTS"] = args.str(); replacements["PARAMETER_ARGUMENTS"] = args.str();
stringstream loadLocal1; stringstream loadLocal1;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) { if (param.getNumComponents() == 1) {
loadLocal1<<"localData[localAtomIndex]."<<param.getName()<<" = "<<param.getName()<<"1;\n"; loadLocal1<<"localData[localAtomIndex]."<<param.getName()<<" = "<<param.getName()<<"1;\n";
} }
...@@ -651,7 +641,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -651,7 +641,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_1"] = loadLocal1.str();
replacements["DECLARE_LOCAL_PARAMETERS"] = ""; replacements["DECLARE_LOCAL_PARAMETERS"] = "";
stringstream loadLocal2; stringstream loadLocal2;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) { if (param.getNumComponents() == 1) {
loadLocal2<<"localData[localAtomIndex]."<<param.getName()<<" = global_"<<param.getName()<<"[j];\n"; loadLocal2<<"localData[localAtomIndex]."<<param.getName()<<" = global_"<<param.getName()<<"[j];\n";
} }
...@@ -666,7 +656,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -666,7 +656,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
} }
replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str(); replacements["LOAD_LOCAL_PARAMETERS_FROM_GLOBAL"] = loadLocal2.str();
stringstream load1; stringstream load1;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
load1<<param.getType()<<" "<<param.getName()<<"1 = "; load1<<param.getType()<<" "<<param.getName()<<"1 = ";
if (param.getNumComponents() == 3) if (param.getNumComponents() == 3)
load1<<"make_"<<param.getType()<<"(global_"<<param.getName()<<"[3*atom1], global_"<<param.getName()<<"[3*atom1+1], global_"<<param.getName()<<"[3*atom1+2]);\n"; load1<<"make_"<<param.getType()<<"(global_"<<param.getName()<<"[3*atom1], global_"<<param.getName()<<"[3*atom1+1], global_"<<param.getName()<<"[3*atom1+2]);\n";
...@@ -675,7 +665,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -675,7 +665,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
} }
replacements["LOAD_ATOM1_PARAMETERS"] = load1.str(); replacements["LOAD_ATOM1_PARAMETERS"] = load1.str();
stringstream load2j; stringstream load2j;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) { if (param.getNumComponents() == 1) {
load2j<<param.getType()<<" "<<param.getName()<<"2 = localData[atom2]."<<param.getName()<<";\n"; load2j<<param.getType()<<" "<<param.getName()<<"2 = localData[atom2]."<<param.getName()<<";\n";
} }
...@@ -691,7 +681,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -691,7 +681,7 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
} }
replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str(); replacements["LOAD_ATOM2_PARAMETERS"] = load2j.str();
stringstream clearLocal; stringstream clearLocal;
for (const ParameterInfo& param : params) { for (const ComputeParameterInfo& param : params) {
if (param.getNumComponents() == 1) if (param.getNumComponents() == 1)
clearLocal<<"localData[localAtomIndex]."<<param.getName()<<" = 0;\n"; clearLocal<<"localData[localAtomIndex]."<<param.getName()<<" = 0;\n";
else else
...@@ -775,10 +765,10 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc ...@@ -775,10 +765,10 @@ cl::Kernel OpenCLNonbondedUtilities::createInteractionKernel(const string& sourc
kernel.setArg<cl::Buffer>(index++, blockBoundingBox.getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, blockBoundingBox.getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, interactingAtoms.getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, interactingAtoms.getDeviceBuffer());
} }
for (const ParameterInfo& param : params) for (ComputeParameterInfo& param : params)
kernel.setArg<cl::Memory>(index++, param.getMemory()); kernel.setArg<cl::Memory>(index++, context.unwrap(param.getArray()).getDeviceBuffer());
for (const ParameterInfo& arg : arguments) for (ComputeParameterInfo& arg : arguments)
kernel.setArg<cl::Memory>(index++, arg.getMemory()); kernel.setArg<cl::Memory>(index++, context.unwrap(arg.getArray()).getDeviceBuffer());
if (energyParameterDerivatives.size() > 0) if (energyParameterDerivatives.size() > 0)
kernel.setArg<cl::Memory>(index++, context.getEnergyParamDerivBuffer().getDeviceBuffer()); kernel.setArg<cl::Memory>(index++, context.getEnergyParamDerivBuffer().getDeviceBuffer());
return kernel; return kernel;
......
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2009-2012 Stanford University and the Authors. *
* Authors: Peter Eastman *
* Contributors: *
* *
* This program is free software: you can redistribute it and/or modify *
* it under the terms of the GNU Lesser General Public License as published *
* by the Free Software Foundation, either version 3 of the License, or *
* (at your option) any later version. *
* *
* This program is distributed in the hope that it will be useful, *
* but WITHOUT ANY WARRANTY; without even the implied warranty of *
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the *
* GNU Lesser General Public License for more details. *
* *
* You should have received a copy of the GNU Lesser General Public License *
* along with this program. If not, see <http://www.gnu.org/licenses/>. *
* -------------------------------------------------------------------------- */
#include "OpenCLParameterSet.h"
using namespace OpenMM;
using namespace std;
OpenCLParameterSet::OpenCLParameterSet(OpenCLContext& context, int numParameters, int numObjects, const string& name, bool bufferPerParameter, bool useDoublePrecision) :
ComputeParameterSet(context, numParameters, numObjects, name, bufferPerParameter, useDoublePrecision) {
for (auto& info : getParameterInfos()) {
buffers.push_back(OpenCLNonbondedUtilities::ParameterInfo(info.getName(), info.getComponentType(), info.getNumComponents(), info.getSize(), context.unwrap(info.getArray()).getDeviceBuffer()));
}
}
...@@ -196,7 +196,7 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri ...@@ -196,7 +196,7 @@ void OpenCLPlatform::contextCreated(ContextImpl& context, const map<string, stri
char* threadsEnv = getenv("OPENMM_CPU_THREADS"); char* threadsEnv = getenv("OPENMM_CPU_THREADS");
if (threadsEnv != NULL) if (threadsEnv != NULL)
stringstream(threadsEnv) >> threads; stringstream(threadsEnv) >> threads;
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue, context.setPlatformData(new PlatformData(context.getSystem(), &context, platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue,
pmeStreamPropValue, threads, NULL)); pmeStreamPropValue, threads, NULL));
} }
...@@ -208,7 +208,7 @@ void OpenCLPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& ori ...@@ -208,7 +208,7 @@ void OpenCLPlatform::linkedContextCreated(ContextImpl& context, ContextImpl& ori
string cpuPmePropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLUseCpuPme()); string cpuPmePropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLUseCpuPme());
string pmeStreamPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLDisablePmeStream()); string pmeStreamPropValue = platform.getPropertyValue(originalContext.getOwner(), OpenCLDisablePmeStream());
int threads = reinterpret_cast<PlatformData*>(originalContext.getPlatformData())->threads.getNumThreads(); int threads = reinterpret_cast<PlatformData*>(originalContext.getPlatformData())->threads.getNumThreads();
context.setPlatformData(new PlatformData(context.getSystem(), platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue, context.setPlatformData(new PlatformData(context.getSystem(), &context, platformPropValue, devicePropValue, precisionPropValue, cpuPmePropValue,
pmeStreamPropValue, threads, &originalContext)); pmeStreamPropValue, threads, &originalContext));
} }
...@@ -217,9 +217,9 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const { ...@@ -217,9 +217,9 @@ void OpenCLPlatform::contextDestroyed(ContextImpl& context) const {
delete data; delete data;
} }
OpenCLPlatform::PlatformData::PlatformData(const System& system, const string& platformPropValue, const string& deviceIndexProperty, OpenCLPlatform::PlatformData::PlatformData(const System& system, ContextImpl* context, const string& platformPropValue, const string& deviceIndexProperty,
const string& precisionProperty, const string& cpuPmeProperty, const string& pmeStreamProperty, int numThreads, ContextImpl* originalContext) : const string& precisionProperty, const string& cpuPmeProperty, const string& pmeStreamProperty, int numThreads, ContextImpl* originalContext) :
removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false), threads(numThreads) { context(context), removeCM(false), stepCount(0), computeForceCount(0), time(0.0), hasInitializedContexts(false), threads(numThreads) {
int platformIndex = -1; int platformIndex = -1;
if (platformPropValue.length() > 0) if (platformPropValue.length() > 0)
stringstream(platformPropValue) >> platformIndex; stringstream(platformPropValue) >> platformIndex;
......
...@@ -58,7 +58,7 @@ template <class Real2> ...@@ -58,7 +58,7 @@ template <class Real2>
void testTransform(bool realToComplex, int xsize, int ysize, int zsize) { void testTransform(bool realToComplex, int xsize, int ysize, int zsize) {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL); OpenCLPlatform::PlatformData platformData(system, NULL, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
OpenMM_SFMT::SFMT sfmt; OpenMM_SFMT::SFMT sfmt;
......
...@@ -54,7 +54,7 @@ void testGaussian() { ...@@ -54,7 +54,7 @@ void testGaussian() {
System system; System system;
for (int i = 0; i < numAtoms; i++) for (int i = 0; i < numAtoms; i++)
system.addParticle(1.0); system.addParticle(1.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL); OpenCLPlatform::PlatformData platformData(system, NULL, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
context.getIntegrationUtilities().initRandomNumberGenerator(0); context.getIntegrationUtilities().initRandomNumberGenerator(0);
......
...@@ -64,7 +64,7 @@ void verifySorting(vector<float> array, bool uniform) { ...@@ -64,7 +64,7 @@ void verifySorting(vector<float> array, bool uniform) {
System system; System system;
system.addParticle(0.0); system.addParticle(0.0);
OpenCLPlatform::PlatformData platformData(system, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL); OpenCLPlatform::PlatformData platformData(system, NULL, "", "", platform.getPropertyDefaultValue("OpenCLPrecision"), "false", "false", 1, NULL);
OpenCLContext& context = *platformData.contexts[0]; OpenCLContext& context = *platformData.contexts[0];
context.initialize(); context.initialize();
OpenCLArray data(context, array.size(), sizeof(float), "sortData"); OpenCLArray data(context, array.size(), sizeof(float), "sortData");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment