Commit ea4873f0 authored by peastman
Browse files

Code cleanup to OpenCL platform

parent 9f37b18b
......@@ -182,7 +182,27 @@ public:
* Copy the values in a vector to the Buffer.
*/
template <class T>
void upload(const std::vector<T>& data, bool blocking = true) {
void upload(const std::vector<T>& data, bool blocking = true, bool convert = false) {
if (convert && data.size() == size && sizeof(T) != elementSize) {
if (sizeof(T) == 2*elementSize) {
// Convert values from double to single precision.
const double* d = reinterpret_cast<const double*>(&data[0]);
std::vector<float> v(elementSize*size/sizeof(float));
for (int i = 0; i < v.size(); i++)
v[i] = (float) d[i];
upload(&v[0], blocking);
return;
}
if (2*sizeof(T) == elementSize) {
// Convert values from single to double precision.
const float* d = reinterpret_cast<const float*>(&data[0]);
std::vector<double> v(elementSize*size/sizeof(double));
for (int i = 0; i < v.size(); i++)
v[i] = (double) d[i];
upload(&v[0], blocking);
return;
}
}
if (sizeof(T) != elementSize || data.size() != size)
throw OpenMMException("Error uploading array "+name+": The specified vector does not match the size of the array");
upload(&data[0], blocking);
......
......@@ -1581,10 +1581,8 @@ private:
mutable std::vector<std::vector<mm_double4> > localPerDofValuesDouble;
std::map<std::string, double> energyParamDerivs;
std::vector<std::string> perDofEnergyParamDerivNames;
std::vector<cl_float> localPerDofEnergyParamDerivsFloat;
std::vector<cl_double> localPerDofEnergyParamDerivsDouble;
std::vector<float> globalValuesFloat;
std::vector<double> globalValuesDouble;
std::vector<cl_double> localPerDofEnergyParamDerivs;
std::vector<double> localGlobalValues;
std::vector<double> initialGlobalVariables;
std::vector<std::vector<cl::Kernel> > kernels;
cl::Kernel randomKernel, kineticEnergyKernel, sumKineticEnergyKernel;
......
......@@ -769,18 +769,10 @@ double OpenCLContext::reduceEnergy() {
void OpenCLContext::setCharges(const vector<double>& charges) {
if (!chargeBuffer.isInitialized())
chargeBuffer.initialize(*this, numAtoms, useDoublePrecision ? sizeof(double) : sizeof(float), "chargeBuffer");
if (getUseDoublePrecision()) {
vector<double> c(numAtoms);
for (int i = 0; i < numAtoms; i++)
c[i] = charges[i];
chargeBuffer.upload(c);
}
else {
vector<float> c(numAtoms);
for (int i = 0; i < numAtoms; i++)
c[i] = (float) charges[i];
chargeBuffer.upload(c);
}
chargeBuffer.upload(c, true, true);
setChargesKernel.setArg<cl::Buffer>(0, chargeBuffer.getDeviceBuffer());
setChargesKernel.setArg<cl::Buffer>(1, posq.getDeviceBuffer());
setChargesKernel.setArg<cl::Buffer>(2, atomIndexDevice.getDeviceBuffer());
......
......@@ -395,12 +395,12 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
vector<cl_int> atomConstraintsVec(ccmaAtomConstraints.getSize());
vector<cl_int> numAtomConstraintsVec(ccmaNumAtomConstraints.getSize());
vector<cl_int> constraintMatrixColumnVec(ccmaConstraintMatrixColumn.getSize());
if (context.getUseDoublePrecision() || context.getUseMixedPrecision()) {
ccmaDistance.initialize<mm_double4>(context, numCCMA, "CcmaDistance");
ccmaDelta1.initialize<cl_double>(context, numCCMA, "CcmaDelta1");
ccmaDelta2.initialize<cl_double>(context, numCCMA, "CcmaDelta2");
ccmaReducedMass.initialize<cl_double>(context, numCCMA, "CcmaReducedMass");
ccmaConstraintMatrixValue.initialize<cl_double>(context, numCCMA*maxRowElements, "ConstraintMatrixValue");
int elementSize = (context.getUseDoublePrecision() || context.getUseMixedPrecision() ? sizeof(cl_double) : sizeof(cl_float));
ccmaDistance.initialize(context, numCCMA, 4*elementSize, "CcmaDistance");
ccmaDelta1.initialize(context, numCCMA, elementSize, "CcmaDelta1");
ccmaDelta2.initialize(context, numCCMA, elementSize, "CcmaDelta2");
ccmaReducedMass.initialize(context, numCCMA, elementSize, "CcmaReducedMass");
ccmaConstraintMatrixValue.initialize(context, numCCMA*maxRowElements, elementSize, "ConstraintMatrixValue");
vector<mm_double4> distanceVec(ccmaDistance.getSize());
vector<cl_double> reducedMassVec(ccmaReducedMass.getSize());
vector<cl_double> constraintMatrixValueVec(ccmaConstraintMatrixValue.getSize());
......@@ -424,43 +424,9 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
atomConstraintsVec[i+j*numAtoms] = (forward ? inverseOrder[atomConstraints[i][j]]+1 : -inverseOrder[atomConstraints[i][j]]-1);
}
}
ccmaDistance.upload(distanceVec);
ccmaReducedMass.upload(reducedMassVec);
ccmaConstraintMatrixValue.upload(constraintMatrixValueVec);
}
else {
ccmaDistance.initialize<mm_float4>(context, numCCMA, "CcmaDistance");
ccmaDelta1.initialize<cl_float>(context, numCCMA, "CcmaDelta1");
ccmaDelta2.initialize<cl_float>(context, numCCMA, "CcmaDelta2");
ccmaReducedMass.initialize<cl_float>(context, numCCMA, "CcmaReducedMass");
ccmaConstraintMatrixValue.initialize<cl_float>(context, numCCMA*maxRowElements, "ConstraintMatrixValue");
vector<mm_float4> distanceVec(ccmaDistance.getSize());
vector<cl_float> reducedMassVec(ccmaReducedMass.getSize());
vector<cl_float> constraintMatrixValueVec(ccmaConstraintMatrixValue.getSize());
for (int i = 0; i < numCCMA; i++) {
int index = constraintOrder[i];
int c = ccmaConstraints[index];
atomsVec[i].x = atom1[c];
atomsVec[i].y = atom2[c];
distanceVec[i].w = (float) distance[c];
reducedMassVec[i] = (float) (0.5/(1.0/system.getParticleMass(atom1[c])+1.0/system.getParticleMass(atom2[c])));
for (unsigned int j = 0; j < matrix[index].size(); j++) {
constraintMatrixColumnVec[i+j*numCCMA] = matrix[index][j].first;
constraintMatrixValueVec[i+j*numCCMA] = (float) matrix[index][j].second;
}
constraintMatrixColumnVec[i+matrix[index].size()*numCCMA] = numCCMA;
}
for (unsigned int i = 0; i < atomConstraints.size(); i++) {
numAtomConstraintsVec[i] = atomConstraints[i].size();
for (unsigned int j = 0; j < atomConstraints[i].size(); j++) {
bool forward = (atom1[ccmaConstraints[atomConstraints[i][j]]] == i);
atomConstraintsVec[i+j*numAtoms] = (forward ? inverseOrder[atomConstraints[i][j]]+1 : -inverseOrder[atomConstraints[i][j]]-1);
}
}
ccmaDistance.upload(distanceVec);
ccmaReducedMass.upload(reducedMassVec);
ccmaConstraintMatrixValue.upload(constraintMatrixValueVec);
}
ccmaDistance.upload(distanceVec, true, true);
ccmaReducedMass.upload(reducedMassVec, true, true);
ccmaConstraintMatrixValue.upload(constraintMatrixValueVec, true, true);
ccmaAtoms.upload(atomsVec);
ccmaAtomConstraints.upload(atomConstraintsVec);
ccmaNumAtomConstraints.upload(numAtomConstraintsVec);
......@@ -563,57 +529,21 @@ OpenCLIntegrationUtilities::OpenCLIntegrationUtilities(OpenCLContext& context, c
vsiteLocalCoordsAtoms.upload(vsiteLocalCoordsAtomVec);
vsiteLocalCoordsStartIndex.upload(vsiteLocalCoordsStartVec);
}
if (context.getUseDoublePrecision()) {
vsite2AvgWeights.initialize<mm_double2>(context, max(1, num2Avg), "vsite2AvgWeights");
vsite3AvgWeights.initialize<mm_double4>(context, max(1, num3Avg), "vsite3AvgWeights");
vsiteOutOfPlaneWeights.initialize<mm_double4>(context, max(1, numOutOfPlane), "vsiteOutOfPlaneWeights");
vsiteLocalCoordsWeights.initialize<cl_double>(context, max(1, (int) vsiteLocalCoordsWeightVec.size()), "vsiteLocalCoordsWeights");
vsiteLocalCoordsPos.initialize<mm_double4>(context, max(1, (int) vsiteLocalCoordsPosVec.size()), "vsiteLocalCoordsPos");
int elementSize = (context.getUseDoublePrecision() ? sizeof(cl_double) : sizeof(cl_float));
vsite2AvgWeights.initialize(context, max(1, num2Avg), 2*elementSize, "vsite2AvgWeights");
vsite3AvgWeights.initialize(context, max(1, num3Avg), 4*elementSize, "vsite3AvgWeights");
vsiteOutOfPlaneWeights.initialize(context, max(1, numOutOfPlane), 4*elementSize, "vsiteOutOfPlaneWeights");
vsiteLocalCoordsWeights.initialize(context, max(1, (int) vsiteLocalCoordsWeightVec.size()), elementSize, "vsiteLocalCoordsWeights");
vsiteLocalCoordsPos.initialize(context, max(1, (int) vsiteLocalCoordsPosVec.size()), 4*elementSize, "vsiteLocalCoordsPos");
if (num2Avg > 0)
vsite2AvgWeights.upload(vsite2AvgWeightVec);
vsite2AvgWeights.upload(vsite2AvgWeightVec, true, true);
if (num3Avg > 0)
vsite3AvgWeights.upload(vsite3AvgWeightVec);
vsite3AvgWeights.upload(vsite3AvgWeightVec, true, true);
if (numOutOfPlane > 0)
vsiteOutOfPlaneWeights.upload(vsiteOutOfPlaneWeightVec);
vsiteOutOfPlaneWeights.upload(vsiteOutOfPlaneWeightVec, true, true);
if (numLocalCoords > 0) {
vsiteLocalCoordsWeights.upload(vsiteLocalCoordsWeightVec);
vsiteLocalCoordsPos.upload(vsiteLocalCoordsPosVec);
}
}
else {
vsite2AvgWeights.initialize<mm_float2>(context, max(1, num2Avg), "vsite2AvgWeights");
vsite3AvgWeights.initialize<mm_float4>(context, max(1, num3Avg), "vsite3AvgWeights");
vsiteOutOfPlaneWeights.initialize<mm_float4>(context, max(1, numOutOfPlane), "vsiteOutOfPlaneWeights");
vsiteLocalCoordsWeights.initialize<cl_float>(context, max(1, (int) vsiteLocalCoordsWeightVec.size()), "vsiteLocalCoordsWeights");
vsiteLocalCoordsPos.initialize<mm_float4>(context, max(1, (int) vsiteLocalCoordsPosVec.size()), "vsiteLocalCoordsPos");
if (num2Avg > 0) {
vector<mm_float2> floatWeights(num2Avg);
for (int i = 0; i < num2Avg; i++)
floatWeights[i] = mm_float2((float) vsite2AvgWeightVec[i].x, (float) vsite2AvgWeightVec[i].y);
vsite2AvgWeights.upload(floatWeights);
}
if (num3Avg > 0) {
vector<mm_float4> floatWeights(num3Avg);
for (int i = 0; i < num3Avg; i++)
floatWeights[i] = mm_float4((float) vsite3AvgWeightVec[i].x, (float) vsite3AvgWeightVec[i].y, (float) vsite3AvgWeightVec[i].z, 0.0f);
vsite3AvgWeights.upload(floatWeights);
}
if (numOutOfPlane > 0) {
vector<mm_float4> floatWeights(numOutOfPlane);
for (int i = 0; i < numOutOfPlane; i++)
floatWeights[i] = mm_float4((float) vsiteOutOfPlaneWeightVec[i].x, (float) vsiteOutOfPlaneWeightVec[i].y, (float) vsiteOutOfPlaneWeightVec[i].z, 0.0f);
vsiteOutOfPlaneWeights.upload(floatWeights);
}
if (numLocalCoords > 0) {
vector<cl_float> floatWeights(vsiteLocalCoordsWeightVec.size());
for (int i = 0; i < (int) vsiteLocalCoordsWeightVec.size(); i++)
floatWeights[i] = (cl_float) vsiteLocalCoordsWeightVec[i];
vsiteLocalCoordsWeights.upload(floatWeights);
vector<mm_float4> floatPos(vsiteLocalCoordsPosVec.size());
for (int i = 0; i < (int) vsiteLocalCoordsPosVec.size(); i++)
floatPos[i] = mm_float4((float) vsiteLocalCoordsPosVec[i].x, (float) vsiteLocalCoordsPosVec[i].y, (float) vsiteLocalCoordsPosVec[i].z, 0.0f);
vsiteLocalCoordsPos.upload(floatPos);
}
vsiteLocalCoordsWeights.upload(vsiteLocalCoordsWeightVec, true, true);
vsiteLocalCoordsPos.upload(vsiteLocalCoordsPosVec, true, true);
}
// If multiple virtual sites depend on the same particle, make sure the force distribution
......
......@@ -1829,32 +1829,19 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
sc += bsplines_data[j]*cos(arg);
ss += bsplines_data[j]*sin(arg);
}
moduli[i] = (float) (sc*sc+ss*ss);
moduli[i] = sc*sc+ss*ss;
}
for (int i = 0; i < ndata; i++)
{
if (moduli[i] < 1.0e-7)
moduli[i] = (moduli[i-1]+moduli[i+1])*0.5f;
}
if (cl.getUseDoublePrecision()) {
if (dim == 0)
xmoduli->upload(moduli);
else if (dim == 1)
ymoduli->upload(moduli);
else
zmoduli->upload(moduli);
}
else {
vector<float> modulif(ndata);
for (int i = 0; i < ndata; i++)
modulif[i] = (float) moduli[i];
if (dim == 0)
xmoduli->upload(modulif);
xmoduli->upload(moduli, true, true);
else if (dim == 1)
ymoduli->upload(modulif);
ymoduli->upload(moduli, true, true);
else
zmoduli->upload(modulif);
}
zmoduli->upload(moduli, true, true);
}
}
}
......@@ -1872,14 +1859,7 @@ void OpenCLCalcNonbondedForceKernel::initialize(const System& system, const Nonb
replacements["CHARGE2"] = "posq2.w";
}
else {
if (cl.getUseDoublePrecision())
charges.upload(chargeVec);
else {
vector<float> c(charges.getSize());
for (int i = 0; i < c.size(); i++)
c[i] = (float) chargeVec[i];
charges.upload(c);
}
charges.upload(chargeVec, true, true);
replacements["CHARGE1"] = prefix+"charge1";
replacements["CHARGE2"] = prefix+"charge2";
}
......@@ -2342,16 +2322,8 @@ void OpenCLCalcNonbondedForceKernel::copyParametersToContext(ContextImpl& contex
sigmaEpsilonVector[i] = mm_float2(0,0);
if (usePosqCharges)
cl.setCharges(chargeVector);
else {
if (cl.getUseDoublePrecision())
charges.upload(chargeVector);
else {
vector<float> c(charges.getSize());
for (int i = 0; i < c.size(); i++)
c[i] = (float) chargeVector[i];
charges.upload(c);
}
}
else
charges.upload(chargeVector, true, true);
sigmaEpsilon.upload(sigmaEpsilonVector);
// Record the exceptions.
......@@ -3022,14 +2994,7 @@ void OpenCLCalcGBSAOBCForceKernel::initialize(const System& system, const GBSAOB
chargeVec[i] = charge;
paramsVector[i] = mm_float2((float) radius, (float) (scalingFactor*radius));
}
if (cl.getUseDoublePrecision())
charges.upload(chargeVec);
else {
vector<float> c(charges.getSize());
for (int i = 0; i < c.size(); i++)
c[i] = (float) chargeVec[i];
charges.upload(c);
}
charges.upload(chargeVec, true, true);
params.upload(paramsVector);
prefactor = -ONE_4PI_EPS0*((1.0/force.getSoluteDielectric())-(1.0/force.getSolventDielectric()));
surfaceAreaFactor = -6.0*4*M_PI*force.getSurfaceAreaEnergy();
......@@ -3196,14 +3161,7 @@ void OpenCLCalcGBSAOBCForceKernel::copyParametersToContext(ContextImpl& context,
}
for (int i = numParticles; i < cl.getPaddedNumAtoms(); i++)
paramsVector[i] = mm_float2(1,1);
if (cl.getUseDoublePrecision())
charges.upload(chargeVector);
else {
vector<float> c(charges.getSize());
for (int i = 0; i < c.size(); i++)
c[i] = (float) chargeVector[i];
charges.upload(c);
}
charges.upload(chargeVector, true, true);
params.upload(paramsVector);
// Mark that the current reordering may be invalid.
......@@ -5051,8 +5009,7 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
numGroups = force.getNumGroups();
vector<cl_int> groupParticleVec;
vector<cl_float> groupWeightVecFloat;
vector<cl_double> groupWeightVecDouble;
vector<cl_double> groupWeightVec;
vector<cl_int> groupOffsetVec;
groupOffsetVec.push_back(0);
for (int i = 0; i < numGroups; i++) {
......@@ -5064,27 +5021,19 @@ void OpenCLCalcCustomCentroidBondForceKernel::initialize(const System& system, c
}
vector<vector<double> > normalizedWeights;
CustomCentroidBondForceImpl::computeNormalizedWeights(force, system, normalizedWeights);
if (cl.getUseDoublePrecision()) {
for (int i = 0; i < numGroups; i++)
groupWeightVecDouble.insert(groupWeightVecDouble.end(), normalizedWeights[i].begin(), normalizedWeights[i].end());
}
else {
for (int i = 0; i < numGroups; i++)
for (int j = 0; j < normalizedWeights[i].size(); j++)
groupWeightVecFloat.push_back((float) normalizedWeights[i][j]);
}
groupWeightVec.insert(groupWeightVec.end(), normalizedWeights[i].begin(), normalizedWeights[i].end());
groupParticles.initialize<int>(cl, groupParticleVec.size(), "groupParticles");
groupParticles.upload(groupParticleVec);
if (cl.getUseDoublePrecision()) {
groupWeights.initialize<double>(cl, groupParticleVec.size(), "groupWeights");
groupWeights.upload(groupWeightVecDouble);
centerPositions.initialize<mm_double4>(cl, numGroups, "centerPositions");
}
else {
groupWeights.initialize<float>(cl, groupParticleVec.size(), "groupWeights");
groupWeights.upload(groupWeightVecFloat);
centerPositions.initialize<mm_float4>(cl, numGroups, "centerPositions");
}
groupWeights.upload(groupWeightVec, true, true);
groupOffsets.initialize<int>(cl, groupOffsetVec.size(), "groupOffsets");
groupOffsets.upload(groupOffsetVec);
groupForces.initialize<long long>(cl, numGroups*3, "groupForces");
......@@ -7011,18 +6960,10 @@ void OpenCLCalcRMSDForceKernel::recordParameters(const RMSDForce& force) {
// Upload them to the device.
particles.upload(particleVec);
if (cl.getUseDoublePrecision()) {
vector<mm_double4> pos;
for (Vec3 p : centeredPositions)
pos.push_back(mm_double4(p[0], p[1], p[2], 0));
referencePos.upload(pos);
}
else {
vector<mm_float4> pos;
for (Vec3 p : centeredPositions)
pos.push_back(mm_float4(p[0], p[1], p[2], 0));
referencePos.upload(pos);
}
referencePos.upload(pos, true, true);
// Record the sum of the norms of the reference positions.
......@@ -7241,20 +7182,11 @@ void OpenCLIntegrateLangevinStepKernel::execute(ContextImpl& context, const Lang
double vscale = exp(-stepSize*friction);
double fscale = (friction == 0 ? stepSize : (1-vscale)/friction);
double noisescale = sqrt(kT*(1-vscale*vscale));
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
vector<cl_double> p(params.getSize());
p[0] = vscale;
p[1] = fscale;
p[2] = noisescale;
params.upload(p);
}
else {
vector<cl_float> p(params.getSize());
p[0] = (cl_float) vscale;
p[1] = (cl_float) fscale;
p[2] = (cl_float) noisescale;
params.upload(p);
}
params.upload(p, true, true);
prevTemp = temperature;
prevFriction = friction;
prevStepSize = stepSize;
......@@ -7818,22 +7750,20 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
// Allocate space for storing global values, both on the host and the device.
globalValuesFloat.resize(expressionSet.getNumVariables());
globalValuesDouble.resize(expressionSet.getNumVariables());
localGlobalValues.resize(expressionSet.getNumVariables());
int elementSize = (cl.getUseDoublePrecision() || cl.getUseMixedPrecision() ? sizeof(double) : sizeof(float));
globalValues.initialize(cl, expressionSet.getNumVariables(), elementSize, "globalValues");
for (int i = 0; i < integrator.getNumGlobalVariables(); i++) {
globalValuesDouble[globalVariableIndex[i]] = initialGlobalVariables[i];
localGlobalValues[globalVariableIndex[i]] = initialGlobalVariables[i];
expressionSet.setVariable(globalVariableIndex[i], initialGlobalVariables[i]);
}
for (int i = 0; i < (int) parameterVariableIndex.size(); i++) {
double value = context.getParameter(parameterNames[i]);
globalValuesDouble[parameterVariableIndex[i]] = value;
localGlobalValues[parameterVariableIndex[i]] = value;
expressionSet.setVariable(parameterVariableIndex[i], value);
}
int numContextParams = context.getParameters().size();
localPerDofEnergyParamDerivsFloat.resize(numContextParams);
localPerDofEnergyParamDerivsDouble.resize(numContextParams);
localPerDofEnergyParamDerivs.resize(numContextParams);
perDofEnergyParamDerivs.initialize(cl, max(1, numContextParams), elementSize, "perDofEnergyParamDerivs");
// Record information about the targets of steps that will be stored in global variables.
......@@ -8108,8 +8038,8 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator);
for (int i = 0; i < (int) parameterNames.size(); i++) {
double value = context.getParameter(parameterNames[i]);
if (value != globalValuesDouble[parameterVariableIndex[i]]) {
globalValuesDouble[parameterVariableIndex[i]] = value;
if (value != localGlobalValues[parameterVariableIndex[i]]) {
localGlobalValues[parameterVariableIndex[i]] = value;
deviceGlobalsAreCurrent = false;
}
}
......@@ -8203,16 +8133,9 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (needsEnergyParamDerivs) {
context.getEnergyParameterDerivatives(energyParamDerivs);
if (perDofEnergyParamDerivNames.size() > 0) {
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
for (int i = 0; i < perDofEnergyParamDerivNames.size(); i++)
localPerDofEnergyParamDerivsDouble[i] = energyParamDerivs[perDofEnergyParamDerivNames[i]];
perDofEnergyParamDerivs.upload(localPerDofEnergyParamDerivsDouble);
}
else {
for (int i = 0; i < perDofEnergyParamDerivNames.size(); i++)
localPerDofEnergyParamDerivsFloat[i] = (float) energyParamDerivs[perDofEnergyParamDerivNames[i]];
perDofEnergyParamDerivs.upload(localPerDofEnergyParamDerivsFloat);
}
localPerDofEnergyParamDerivs[i] = energyParamDerivs[perDofEnergyParamDerivNames[i]];
perDofEnergyParamDerivs.upload(localPerDofEnergyParamDerivs, true, true);
}
}
forcesAreValid = true;
......@@ -8223,13 +8146,7 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
// Upload the global values to the device.
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision())
globalValues.upload(globalValuesDouble);
else {
for (int j = 0; j < (int) globalValuesDouble.size(); j++)
globalValuesFloat[j] = (float) globalValuesDouble[j];
globalValues.upload(globalValuesFloat);
}
globalValues.upload(localGlobalValues, true, true);
}
bool stepInvalidatesForces = invalidatesForces[step];
if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
......@@ -8380,17 +8297,17 @@ double OpenCLIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& contex
void OpenCLIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
switch (target.type) {
case DT:
if (value != globalValuesDouble[dtVariableIndex])
if (value != localGlobalValues[dtVariableIndex])
deviceGlobalsAreCurrent = false;
expressionSet.setVariable(dtVariableIndex, value);
globalValuesDouble[dtVariableIndex] = value;
localGlobalValues[dtVariableIndex] = value;
cl.getIntegrationUtilities().setNextStepSize(value);
integrator.setStepSize(value);
break;
case VARIABLE:
case PARAMETER:
expressionSet.setVariable(target.variableIndex, value);
globalValuesDouble[target.variableIndex] = value;
localGlobalValues[target.variableIndex] = value;
deviceGlobalsAreCurrent = false;
break;
}
......@@ -8401,8 +8318,8 @@ void OpenCLIntegrateCustomStepKernel::recordChangedParameters(ContextImpl& conte
return;
for (int i = 0; i < (int) parameterNames.size(); i++) {
double value = context.getParameter(parameterNames[i]);
if (value != globalValuesDouble[parameterVariableIndex[i]])
context.setParameter(parameterNames[i], globalValuesDouble[parameterVariableIndex[i]]);
if (value != localGlobalValues[parameterVariableIndex[i]])
context.setParameter(parameterNames[i], localGlobalValues[parameterVariableIndex[i]]);
}
}
......@@ -8415,7 +8332,7 @@ void OpenCLIntegrateCustomStepKernel::getGlobalVariables(ContextImpl& context, v
}
values.resize(numGlobalVariables);
for (int i = 0; i < numGlobalVariables; i++)
values[i] = globalValuesDouble[globalVariableIndex[i]];
values[i] = localGlobalValues[globalVariableIndex[i]];
}
void OpenCLIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, const vector<double>& values) {
......@@ -8428,7 +8345,7 @@ void OpenCLIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, c
return;
}
for (int i = 0; i < numGlobalVariables; i++) {
globalValuesDouble[globalVariableIndex[i]] = values[i];
localGlobalValues[globalVariableIndex[i]] = values[i];
expressionSet.setVariable(globalVariableIndex[i], values[i]);
}
deviceGlobalsAreCurrent = false;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment