/* -------------------------------------------------------------------------- * * OpenMM * * -------------------------------------------------------------------------- * * This is part of the OpenMM molecular simulation toolkit. * * See https://openmm.org/development. * * * * Portions copyright (c) 2008-2025 Stanford University and the Authors. * * Authors: Peter Eastman * * Contributors: * * * * This program is free software: you can redistribute it and/or modify * * it under the terms of the GNU Lesser General Public License as published * * by the Free Software Foundation, either version 3 of the License, or * * (at your option) any later version. * * * * This program is distributed in the hope that it will be useful, * * but WITHOUT ANY WARRANTY; without even the implied warranty of * * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * * GNU Lesser General Public License for more details. * * * * You should have received a copy of the GNU Lesser General Public License * * along with this program. If not, see <http://www.gnu.org/licenses/>. 
* * -------------------------------------------------------------------------- */ #include "openmm/common/CommonIntegrateCustomStepKernel.h" #include "openmm/common/CommonKernelUtilities.h" #include "openmm/common/ContextSelector.h" #include "openmm/common/ExpressionUtilities.h" #include "openmm/Context.h" #include "openmm/internal/ContextImpl.h" #include "ReferenceTabulatedFunction.h" #include "SimTKOpenMMUtilities.h" #include "CommonKernelSources.h" #include "lepton/CustomFunction.h" #include "lepton/ExpressionTreeNode.h" #include "lepton/Operation.h" #include "lepton/Parser.h" #include "lepton/ParsedExpression.h" using namespace OpenMM; using namespace std; using namespace Lepton; class CommonIntegrateCustomStepKernel::ReorderListener : public ComputeContext::ReorderListener { public: ReorderListener(ComputeContext& cc, vector& perDofValues, vector >& localPerDofValuesFloat, vector >& localPerDofValuesDouble, vector& deviceValuesAreCurrent) : cc(cc), perDofValues(perDofValues), localPerDofValuesFloat(localPerDofValuesFloat), localPerDofValuesDouble(localPerDofValuesDouble), deviceValuesAreCurrent(deviceValuesAreCurrent) { int numAtoms = cc.getNumAtoms(); lastAtomOrder.resize(numAtoms); for (int i = 0; i < numAtoms; i++) lastAtomOrder[i] = cc.getAtomIndex()[i]; } void execute() { // Reorder the per-DOF variables to reflect the new atom order. 
if (perDofValues.size() == 0) return; int numAtoms = cc.getNumAtoms(); const vector& order = cc.getAtomIndex(); for (int index = 0; index < perDofValues.size(); index++) { if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) { if (deviceValuesAreCurrent[index]) perDofValues[index].download(localPerDofValuesDouble[index]); vector swap(numAtoms); for (int i = 0; i < numAtoms; i++) swap[lastAtomOrder[i]] = localPerDofValuesDouble[index][i]; for (int i = 0; i < numAtoms; i++) localPerDofValuesDouble[index][i] = swap[order[i]]; perDofValues[index].upload(localPerDofValuesDouble[index]); } else { if (deviceValuesAreCurrent[index]) perDofValues[index].download(localPerDofValuesFloat[index]); vector swap(numAtoms); for (int i = 0; i < numAtoms; i++) swap[lastAtomOrder[i]] = localPerDofValuesFloat[index][i]; for (int i = 0; i < numAtoms; i++) localPerDofValuesFloat[index][i] = swap[order[i]]; perDofValues[index].upload(localPerDofValuesFloat[index]); } deviceValuesAreCurrent[index] = true; } for (int i = 0; i < numAtoms; i++) lastAtomOrder[i] = order[i]; } private: ComputeContext& cc; vector& perDofValues; vector >& localPerDofValuesFloat; vector >& localPerDofValuesDouble; vector& deviceValuesAreCurrent; vector lastAtomOrder; }; class CommonIntegrateCustomStepKernel::DerivFunction : public CustomFunction { public: DerivFunction(map& energyParamDerivs, const string& param) : energyParamDerivs(energyParamDerivs), param(param) { } int getNumArguments() const { return 0; } double evaluate(const double* arguments) const { return energyParamDerivs[param]; } double evaluateDerivative(const double* arguments, const int* derivOrder) const { return 0; } CustomFunction* clone() const { return new DerivFunction(energyParamDerivs, param); } private: map& energyParamDerivs; string param; }; void CommonIntegrateCustomStepKernel::initialize(const System& system, const CustomIntegrator& integrator) { cc.initializeContexts(); ContextSelector selector(cc); 
cc.getIntegrationUtilities().initRandomNumberGenerator(integrator.getRandomNumberSeed()); numGlobalVariables = integrator.getNumGlobalVariables(); int elementSize = (cc.getUseDoublePrecision() || cc.getUseMixedPrecision() ? sizeof(double) : sizeof(float)); sumBuffer.initialize(cc, system.getNumParticles(), elementSize, "sumBuffer"); summedValue.initialize(cc, 1, elementSize, "summedValue"); perDofValues.resize(integrator.getNumPerDofVariables()); localPerDofValuesFloat.resize(perDofValues.size()); localPerDofValuesDouble.resize(perDofValues.size()); for (int i = 0; i < perDofValues.size(); i++) perDofValues[i].initialize(cc, system.getNumParticles(), 4*elementSize, "perDofVariables"); localValuesAreCurrent.resize(integrator.getNumPerDofVariables(), false); deviceValuesAreCurrent.resize(integrator.getNumPerDofVariables(), false); cc.addReorderListener(new ReorderListener(cc, perDofValues, localPerDofValuesFloat, localPerDofValuesDouble, deviceValuesAreCurrent)); SimTKOpenMMUtilities::setRandomNumberSeed(integrator.getRandomNumberSeed()); } string CommonIntegrateCustomStepKernel::createPerDofComputation(const string& variable, const Lepton::ParsedExpression& expr, CustomIntegrator& integrator, const string& forceName, const string& energyName, vector& functions, vector >& functionNames) { string tempType = (cc.getSupportsDoublePrecision() ? 
"double3" : "float3"); map expressions; expressions[tempType+" tempResult = "] = expr; map variables; variables["x"] = "make_"+tempType+"(position.x, position.y, position.z)"; variables["v"] = "make_"+tempType+"(velocity.x, velocity.y, velocity.z)"; variables[forceName] = "make_"+tempType+"(f.x, f.y, f.z)"; variables["gaussian"] = "make_"+tempType+"(gaussian.x, gaussian.y, gaussian.z)"; variables["uniform"] = "make_"+tempType+"(uniform.x, uniform.y, uniform.z)"; variables["m"] = "mass"; variables["dt"] = "stepSize"; if (energyName != "") variables[energyName] = "make_"+tempType+"(energy)"; for (int i = 0; i < integrator.getNumGlobalVariables(); i++) variables[integrator.getGlobalVariableName(i)] = "make_"+tempType+"(globals["+cc.intToString(globalVariableIndex[i])+"])"; for (int i = 0; i < integrator.getNumPerDofVariables(); i++) variables[integrator.getPerDofVariableName(i)] = "convertToTempType3(perDof"+cc.intToString(i)+")"; for (int i = 0; i < (int) parameterNames.size(); i++) variables[parameterNames[i]] = "make_"+tempType+"(globals["+cc.intToString(parameterVariableIndex[i])+"])"; vector > variableNodes; findExpressionsForDerivs(expr.getRootNode(), variableNodes); for (auto& var : variables) variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(var.first)), var.second)); string result = cc.getExpressionUtilities().createExpressions(expressions, variableNodes, functions, functionNames, "temp", tempType); if (variable == "x") result += "position.x = tempResult.x; position.y = tempResult.y; position.z = tempResult.z;\n"; else if (variable == "v") result += "velocity.x = tempResult.x; velocity.y = tempResult.y; velocity.z = tempResult.z;\n"; else if (variable == "") result += "sum[index] = tempResult.x+tempResult.y+tempResult.z;\n"; else { for (int i = 0; i < integrator.getNumPerDofVariables(); i++) if (variable == integrator.getPerDofVariableName(i)) { string varName = "perDof"+cc.intToString(i); result += varName+".x = tempResult.x; 
"+varName+".y = tempResult.y; "+varName+".z = tempResult.z;\n"; } } return result; } void CommonIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { ContextSelector selector(cc); IntegrationUtilities& integration = cc.getIntegrationUtilities(); int numAtoms = cc.getNumAtoms(); int numSteps = integrator.getNumComputations(); bool useDouble = cc.getUseDoublePrecision() || cc.getUseMixedPrecision(); string tempType = (cc.getSupportsDoublePrecision() ? "double3" : "float3"); string perDofType = (useDouble ? "double4" : "float4"); if (!hasInitializedKernels) { hasInitializedKernels = true; // Initialize various data structures. const map& params = context.getParameters(); for (auto& param : params) parameterNames.push_back(param.first); kernels.resize(integrator.getNumComputations()); requiredGaussian.resize(integrator.getNumComputations(), 0); requiredUniform.resize(integrator.getNumComputations(), 0); needsGlobals.resize(numSteps, false); globalExpressions.resize(numSteps); stepType.resize(numSteps); stepTarget.resize(numSteps); merged.resize(numSteps, false); modifiesParameters = false; sumWorkGroupSize = cc.getMaxThreadBlockSize(); if (sumWorkGroupSize > 512) sumWorkGroupSize = 512; map defines; defines["NUM_ATOMS"] = cc.intToString(cc.getNumAtoms()); defines["PADDED_NUM_ATOMS"] = cc.intToString(cc.getPaddedNumAtoms()); defines["WORK_GROUP_SIZE"] = cc.intToString(sumWorkGroupSize); // Record the tabulated functions. 
map functions; vector > functionNames; vector functionList; vector tableTypes; tabulatedFunctions.resize(integrator.getNumTabulatedFunctions()); for (int i = 0; i < integrator.getNumTabulatedFunctions(); i++) { functionList.push_back(&integrator.getTabulatedFunction(i)); string name = integrator.getTabulatedFunctionName(i); string arrayName = "table"+cc.intToString(i); functionNames.push_back(make_pair(name, arrayName)); functions[name] = createReferenceTabulatedFunction(integrator.getTabulatedFunction(i)); int width; vector f = cc.getExpressionUtilities().computeFunctionCoefficients(integrator.getTabulatedFunction(i), width); tabulatedFunctions[i].initialize(cc, f.size(), "TabulatedFunction"); tabulatedFunctions[i].upload(f); if (width == 1) tableTypes.push_back("float"); else tableTypes.push_back("float"+cc.intToString(width)); } // Record information about all the computation steps. vector variable(numSteps); vector forceGroup; vector > expression; CustomIntegratorUtilities::analyzeComputations(context, integrator, expression, comparisons, blockEnd, invalidatesForces, needsForces, needsEnergy, computeBothForceAndEnergy, forceGroup, functions); for (int step = 0; step < numSteps; step++) { string expr; integrator.getComputationStep(step, stepType[step], variable[step], expr); if (stepType[step] == CustomIntegrator::WhileBlockStart) blockEnd[blockEnd[step]] = step; // Record where to branch back to. if (stepType[step] == CustomIntegrator::ComputeGlobal || stepType[step] == CustomIntegrator::IfBlockStart || stepType[step] == CustomIntegrator::WhileBlockStart) for (auto& expr : expression[step]) globalExpressions[step].push_back(ParsedExpression(replaceDerivFunctions(expr.getRootNode(), context)).createCompiledExpression()); } for (int step = 0; step < numSteps; step++) { for (auto& expr : globalExpressions[step]) expressionSet.registerExpression(expr); } // Record the indices for variables in the CompiledExpressionSet. 
gaussianVariableIndex = expressionSet.getVariableIndex("gaussian"); uniformVariableIndex = expressionSet.getVariableIndex("uniform"); dtVariableIndex = expressionSet.getVariableIndex("dt"); for (int i = 0; i < integrator.getNumGlobalVariables(); i++) globalVariableIndex.push_back(expressionSet.getVariableIndex(integrator.getGlobalVariableName(i))); for (auto& name : parameterNames) parameterVariableIndex.push_back(expressionSet.getVariableIndex(name)); // Record the variable names and flags for the force and energy in each step. forceGroupFlags.resize(numSteps, integrator.getIntegrationForceGroups()); vector forceGroupName; vector energyGroupName; for (int i = 0; i < 32; i++) { stringstream fname; fname << "f" << i; forceGroupName.push_back(fname.str()); stringstream ename; ename << "energy" << i; energyGroupName.push_back(ename.str()); } vector forceName(numSteps, "f"); vector energyName(numSteps, "energy"); stepEnergyVariableIndex.resize(numSteps, expressionSet.getVariableIndex("energy")); for (int step = 0; step < numSteps; step++) { if (needsForces[step] && forceGroup[step] > -1) forceName[step] = forceGroupName[forceGroup[step]]; if (needsEnergy[step] && forceGroup[step] > -1) { energyName[step] = energyGroupName[forceGroup[step]]; stepEnergyVariableIndex[step] = expressionSet.getVariableIndex(energyName[step]); } if (forceGroup[step] > -1) forceGroupFlags[step] = 1< 0) forceGroupFlags[step] = forceGroupFlags[step-1]; if (forceGroupFlags[step] != -2 && savedForces.find(forceGroupFlags[step]) == savedForces.end()) { savedForces[forceGroupFlags[step]] = ComputeArray(); savedForces[forceGroupFlags[step]].initialize(cc, cc.getLongForceBuffer().getSize(), cc.getLongForceBuffer().getElementSize(), "savedForces"); } } // Allocate space for storing global values, both on the host and the device. localGlobalValues.resize(expressionSet.getNumVariables()); int elementSize = (cc.getUseDoublePrecision() || cc.getUseMixedPrecision() ? 
sizeof(double) : sizeof(float)); globalValues.initialize(cc, expressionSet.getNumVariables(), elementSize, "globalValues"); for (int i = 0; i < integrator.getNumGlobalVariables(); i++) { localGlobalValues[globalVariableIndex[i]] = initialGlobalVariables[i]; expressionSet.setVariable(globalVariableIndex[i], initialGlobalVariables[i]); } for (int i = 0; i < (int) parameterVariableIndex.size(); i++) { double value = context.getParameter(parameterNames[i]); localGlobalValues[parameterVariableIndex[i]] = value; expressionSet.setVariable(parameterVariableIndex[i], value); } int numContextParams = context.getParameters().size(); localPerDofEnergyParamDerivs.resize(numContextParams); perDofEnergyParamDerivs.initialize(cc, max(1, numContextParams), elementSize, "perDofEnergyParamDerivs"); // Record information about the targets of steps that will be stored in global variables. for (int step = 0; step < numSteps; step++) { if (stepType[step] == CustomIntegrator::ComputeGlobal || stepType[step] == CustomIntegrator::ComputeSum) { if (variable[step] == "dt") stepTarget[step].type = DT; for (int i = 0; i < integrator.getNumGlobalVariables(); i++) if (variable[step] == integrator.getGlobalVariableName(i)) stepTarget[step].type = VARIABLE; for (auto& name : parameterNames) if (variable[step] == name) { stepTarget[step].type = PARAMETER; modifiesParameters = true; } stepTarget[step].variableIndex = expressionSet.getVariableIndex(variable[step]); } } // Identify which per-DOF steps are going to require global variables or context parameters. 
for (int step = 0; step < numSteps; step++) { if (stepType[step] == CustomIntegrator::ComputePerDof || stepType[step] == CustomIntegrator::ComputeSum) { for (int i = 0; i < integrator.getNumGlobalVariables(); i++) if (usesVariable(expression[step][0], integrator.getGlobalVariableName(i))) needsGlobals[step] = true; for (auto& name : parameterNames) if (usesVariable(expression[step][0], name)) needsGlobals[step] = true; } } // Determine how each step will represent the position (as just a value, or a value plus a delta). hasAnyConstraints = (context.getSystem().getNumConstraints() > 0); vector storePosAsDelta(numSteps, false); vector loadPosAsDelta(numSteps, false); if (hasAnyConstraints) { bool beforeConstrain = false; for (int step = numSteps-1; step >= 0; step--) { if (stepType[step] == CustomIntegrator::ConstrainPositions) beforeConstrain = true; else if (stepType[step] == CustomIntegrator::ComputePerDof && variable[step] == "x" && beforeConstrain) storePosAsDelta[step] = true; } bool storedAsDelta = false; for (int step = 0; step < numSteps; step++) { loadPosAsDelta[step] = storedAsDelta; if (storePosAsDelta[step] == true) storedAsDelta = true; if (stepType[step] == CustomIntegrator::ConstrainPositions) storedAsDelta = false; } } // Identify steps that can be merged into a single kernel. 
for (int step = 1; step < numSteps; step++) { if (invalidatesForces[step-1] || forceGroupFlags[step] != forceGroupFlags[step-1]) continue; if (stepType[step-1] == CustomIntegrator::ComputePerDof && stepType[step] == CustomIntegrator::ComputePerDof) merged[step] = true; } for (int step = numSteps-1; step > 0; step--) if (merged[step]) { needsForces[step-1] = (needsForces[step] || needsForces[step-1]); needsEnergy[step-1] = (needsEnergy[step] || needsEnergy[step-1]); needsGlobals[step-1] = (needsGlobals[step] || needsGlobals[step-1]); computeBothForceAndEnergy[step-1] = (computeBothForceAndEnergy[step] || computeBothForceAndEnergy[step-1]); } // Loop over all steps and create the kernels for them. for (int step = 0; step < numSteps; step++) { if ((stepType[step] == CustomIntegrator::ComputePerDof || stepType[step] == CustomIntegrator::ComputeSum) && !merged[step]) { // Compute a per-DOF value. stringstream compute; for (int i = 0; i < perDofValues.size(); i++) compute << tempType<<" perDof"< 0) compute << "float4 gaussian = gaussianValues[gaussianIndex+index];\n"; if (numUniform > 0) compute << "float4 uniform = uniformValues[uniformIndex+index];\n"; compute << createPerDofComputation(stepType[j] == CustomIntegrator::ComputePerDof ? 
variable[j] : "", expression[j][0], integrator, forceName[j], energyName[j], functionList, functionNames); if (variable[j] == "x") { if (storePosAsDelta[j]) { if (cc.getSupportsDoublePrecision()) compute << "posDelta[index] = convertFromDouble4(position-loadPos(posq, posqCorrection, index));\n"; else compute << "posDelta[index] = position-posq[index];\n"; } else compute << "storePos(posq, posqCorrection, index, position);\n"; } else if (variable[j] == "v") { if (cc.getSupportsDoublePrecision()) compute << "velm[index] = convertFromDouble4(velocity);\n"; else compute << "velm[index] = velocity;\n"; } else { for (int i = 0; i < perDofValues.size(); i++) compute << "perDofValues"< 0) compute << "gaussianIndex += NUM_ATOMS;\n"; if (numUniform > 0) compute << "uniformIndex += NUM_ATOMS;\n"; compute << "}\n"; } map replacements; replacements["COMPUTE_STEP"] = compute.str(); stringstream args; for (int i = 0; i < perDofValues.size(); i++) { string valueName = "perDofValues"+cc.intToString(i); args << ", GLOBAL " << perDofType << "* RESTRICT " << valueName; } for (int i = 0; i < (int) tableTypes.size(); i++) args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i; replacements["PARAMETER_ARGUMENTS"] = args.str(); if (loadPosAsDelta[step]) defines["LOAD_POS_AS_DELTA"] = "1"; else if (defines.find("LOAD_POS_AS_DELTA") != defines.end()) defines.erase("LOAD_POS_AS_DELTA"); ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customIntegratorPerDof, replacements), defines); ComputeKernel kernel = program->createKernel("computePerDof"); kernels[step].push_back(kernel); requiredGaussian[step] = numGaussian; requiredUniform[step] = numUniform; kernel->addArg(cc.getPosq()); if (cc.getUseMixedPrecision()) kernel->addArg(cc.getPosqCorrection()); else kernel->addArg(nullptr); kernel->addArg(integration.getPosDelta()); kernel->addArg(cc.getVelm()); kernel->addArg(cc.getLongForceBuffer()); kernel->addArg(integration.getStepSize()); 
kernel->addArg(globalValues); kernel->addArg(sumBuffer); for (int i = 0; i < 4; i++) kernel->addArg(); kernel->addArg(perDofEnergyParamDerivs); for (auto& array : perDofValues) kernel->addArg(array); for (auto& array : tabulatedFunctions) kernel->addArg(array); if (stepType[step] == CustomIntegrator::ComputeSum) { // Create a second kernel for this step that sums the values. program = cc.compileProgram(CommonKernelSources::customIntegrator, defines); kernel = program->createKernel(useDouble ? "computeDoubleSum" : "computeFloatSum"); kernels[step].push_back(kernel); kernel->addArg(sumBuffer); kernel->addArg(summedValue); kernel->addArg(numAtoms); } } else if (stepType[step] == CustomIntegrator::ConstrainPositions) { // Apply position constraints. ComputeProgram program = cc.compileProgram(CommonKernelSources::customIntegrator, defines); ComputeKernel kernel = program->createKernel("applyPositionDeltas"); kernels[step].push_back(kernel); kernel->addArg(cc.getPosq()); if (cc.getUseMixedPrecision()) kernel->addArg(cc.getPosqCorrection()); else kernel->addArg(nullptr); kernel->addArg(integration.getPosDelta()); } } // Initialize the random number generator. 
int maxUniformRandoms = 1; for (int required : requiredUniform) maxUniformRandoms = max(maxUniformRandoms, required); uniformRandoms.initialize(cc, maxUniformRandoms, "uniformRandoms"); randomSeed.initialize(cc, cc.getNumThreadBlocks()*64, "randomSeed"); vector seed(randomSeed.getSize()); int rseed = integrator.getRandomNumberSeed(); // A random seed of 0 means use a unique one if (rseed == 0) rseed = osrngseed(); unsigned int r = (unsigned int) (rseed+1); for (auto& s : seed) { s.x = r = (1664525*r + 1013904223) & 0xFFFFFFFF; s.y = r = (1664525*r + 1013904223) & 0xFFFFFFFF; s.z = r = (1664525*r + 1013904223) & 0xFFFFFFFF; s.w = r = (1664525*r + 1013904223) & 0xFFFFFFFF; } randomSeed.upload(seed); ComputeProgram randomProgram = cc.compileProgram(CommonKernelSources::customIntegrator, defines); randomKernel = randomProgram->createKernel("generateRandomNumbers"); randomKernel->addArg(maxUniformRandoms); randomKernel->addArg(uniformRandoms); randomKernel->addArg(randomSeed); // Create the kernel for computing kinetic energy. 
stringstream computeKE; for (int i = 0; i < perDofValues.size(); i++) computeKE << tempType<<" perDof"< replacements; replacements["COMPUTE_STEP"] = computeKE.str(); stringstream args; for (int i = 0; i < perDofValues.size(); i++) { string valueName = "perDofValues"+cc.intToString(i); args << ", GLOBAL " << perDofType << "* RESTRICT " << valueName; } for (int i = 0; i < (int) tableTypes.size(); i++) args << ", GLOBAL const " << tableTypes[i]<< "* RESTRICT table" << i; replacements["PARAMETER_ARGUMENTS"] = args.str(); if (defines.find("LOAD_POS_AS_DELTA") != defines.end()) defines.erase("LOAD_POS_AS_DELTA"); ComputeProgram program = cc.compileProgram(cc.replaceStrings(CommonKernelSources::customIntegratorPerDof, replacements), defines); kineticEnergyKernel = program->createKernel("computePerDof"); kineticEnergyKernel->addArg(cc.getPosq()); if (cc.getUseMixedPrecision()) kineticEnergyKernel->addArg(cc.getPosqCorrection()); else kineticEnergyKernel->addArg(nullptr); kineticEnergyKernel->addArg(integration.getPosDelta()); kineticEnergyKernel->addArg(cc.getVelm()); kineticEnergyKernel->addArg(cc.getLongForceBuffer()); kineticEnergyKernel->addArg(integration.getStepSize()); kineticEnergyKernel->addArg(globalValues); kineticEnergyKernel->addArg(sumBuffer); kineticEnergyKernel->addArg(); kineticEnergyKernel->addArg(); kineticEnergyKernel->addArg(uniformRandoms); if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) kineticEnergyKernel->addArg(0.0); else kineticEnergyKernel->addArg(0.0f); kineticEnergyKernel->addArg(perDofEnergyParamDerivs); for (auto& array : perDofValues) kineticEnergyKernel->addArg(array); for (auto& array : tabulatedFunctions) kineticEnergyKernel->addArg(array); // Create a second kernel to sum the values. program = cc.compileProgram(CommonKernelSources::customIntegrator, defines); sumKineticEnergyKernel = program->createKernel(useDouble ? 
"computeDoubleSum" : "computeFloatSum"); sumKineticEnergyKernel->addArg(sumBuffer); sumKineticEnergyKernel->addArg(summedValue); sumKineticEnergyKernel->addArg(numAtoms); // Delete the custom functions. for (auto& function : functions) delete function.second; } // Make sure all values (variables, parameters, etc.) are up to date. for (int i = 0; i < perDofValues.size(); i++) { if (!deviceValuesAreCurrent[i]) { if (useDouble) perDofValues[i].upload(localPerDofValuesDouble[i]); else perDofValues[i].upload(localPerDofValuesFloat[i]); deviceValuesAreCurrent[i] = true; } localValuesAreCurrent[i] = false; } double stepSize = integrator.getStepSize(); recordGlobalValue(stepSize, GlobalTarget(DT, dtVariableIndex), integrator); for (int i = 0; i < (int) parameterNames.size(); i++) { double value = context.getParameter(parameterNames[i]); if (value != localGlobalValues[parameterVariableIndex[i]]) { expressionSet.setVariable(parameterVariableIndex[i], value); localGlobalValues[parameterVariableIndex[i]] = value; deviceGlobalsAreCurrent = false; } } } ExpressionTreeNode CommonIntegrateCustomStepKernel::replaceDerivFunctions(const ExpressionTreeNode& node, ContextImpl& context) { // This is called recursively to identify calls to the deriv() function inside global expressions, // and replace them with a custom function that returns the correct value. 
const Operation& op = node.getOperation();
    if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
        // Found a deriv() call. Only the name of its second argument is used.
        // NOTE(review): assumes the parser guarantees deriv() has at least two arguments.
        string param = node.getChildren()[1].getOperation().getName();
        if (context.getParameters().find(param) == context.getParameters().end())
            throw OpenMMException("The second argument to deriv() must be a context parameter");
        needsEnergyParamDerivs = true;

        // The replacement node is created with no children; its value comes from a
        // DerivFunction constructed from energyParamDerivs and the parameter name.
        return ExpressionTreeNode(new Operation::Custom("deriv", new DerivFunction(energyParamDerivs, param)));
    }
    else {
        // Any other node: clone its operation and rebuild it from the recursively processed children.
        vector children;
        for (auto& child : node.getChildren())
            children.push_back(replaceDerivFunctions(child, context));
        return ExpressionTreeNode(op.clone(), children);
    }
}

/**
 * Search the expression tree rooted at node for calls to deriv() inside a per-DOF expression.
 *
 * For each call found, a pair (deriv node, replacement source code) is appended to
 * variableNodes. The replacement code reads energyParamDerivs[index], where index is the
 * parameter's position in perDofEnergyParamDerivNames (new names are appended to that list).
 * Children of a deriv() node are not searched further. needsEnergyParamDerivs is set to
 * true when any deriv() call is found.
 *
 * NOTE(review): unlike replaceDerivFunctions(), the parameter name is not checked against
 * the Context's parameters here.
 *
 * @param node           root of the (sub)expression to search
 * @param variableNodes  receives one entry per deriv() call found
 */
void CommonIntegrateCustomStepKernel::findExpressionsForDerivs(const ExpressionTreeNode& node, vector >& variableNodes) {
    // This is called recursively to identify calls to the deriv() function inside per-DOF expressions,
    // and record the code to replace them with.

    const Operation& op = node.getOperation();
    if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
        string param = node.getChildren()[1].getOperation().getName();

        // Find this parameter's slot in perDofEnergyParamDerivNames, appending it if it is new.
        int index;
        for (index = 0; index < perDofEnergyParamDerivNames.size() && param != perDofEnergyParamDerivNames[index]; index++)
            ;
        if (index == perDofEnergyParamDerivNames.size())
            perDofEnergyParamDerivNames.push_back(param);

        // Build the replacement as make_double3(...) or make_float3(...) around
        // energyParamDerivs[index]. double3 is chosen when the device supports double precision.
        string tempType = (cc.getSupportsDoublePrecision() ? "double3" : "float3");
        variableNodes.push_back(make_pair(node, "make_"+tempType+"(energyParamDerivs["+cc.intToString(index)+"])"));
        needsEnergyParamDerivs = true;
    }
    else {
        for (auto& child : node.getChildren())
            findExpressionsForDerivs(child, variableNodes);
    }
}

/**
 * Advance the simulation by one step, running the integrator's computation steps in order.
 *
 * Forces and energies are evaluated lazily. A step that needs them triggers an evaluation
 * only if the current values cannot be reused:
 * - Energies are cached per force-group set in savedEnergy.
 * - Force buffers for previously used force-group sets may be cached in savedForces.
 * - validSavedForces records which of those cached buffers are currently usable.
 *
 * @param context         the Context being integrated
 * @param integrator      the CustomIntegrator whose computation steps are executed
 * @param forcesAreValid  on entry, whether the forces currently in the force buffer are valid;
 *                        on exit, whether they are still valid after this step
 */
void CommonIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    ContextSelector selector(cc);
    prepareForComputation(context, integrator, forcesAreValid);
    IntegrationUtilities& integration = cc.getIntegrationUtilities();
    int numAtoms = cc.getNumAtoms();
    int numSteps = integrator.getNumComputations();

    // Cached energies are discarded whenever the forces are not valid (the same happens
    // below whenever a step invalidates the forces).
    if (!forcesAreValid)
        savedEnergy.clear();

    // Loop over computation steps in the integrator and execute them. The loop advances
    // through nextStep, which the flow-control steps (if/while blocks) may redirect.

    for (int step = 0; step < numSteps; ) {
        int nextStep = step+1;
        int forceGroups = forceGroupFlags[step];
        int lastForceGroups = context.getLastForceGroups();

        // Decide whether the forces and/or energy this step needs are already available.
        bool haveForces = (!needsForces[step] || (forcesAreValid && lastForceGroups == forceGroups));
        bool haveEnergy = (!needsEnergy[step] || savedEnergy.find(forceGroups) != savedEnergy.end());
        if (!haveForces || !haveEnergy) {
            if (forcesAreValid) {
                if (savedForces.find(lastForceGroups) != savedForces.end() && validSavedForces.find(lastForceGroups) == validSavedForces.end()) {
                    // The forces are still valid. We just need a different force group right now. Save the old
                    // forces in case we need them again.

                    cc.getLongForceBuffer().copyTo(savedForces[lastForceGroups]);
                    validSavedForces.insert(lastForceGroups);
                }
            }
            else
                validSavedForces.clear(); // The current forces are invalid, so no cached buffer is marked usable.

            // Recompute forces and/or energy. Figure out what is actually needed
            // between now and the next time they get invalidated again.

            bool computeForce = (needsForces[step] || computeBothForceAndEnergy[step]);
            bool computeEnergy = (needsEnergy[step] || computeBothForceAndEnergy[step]);
            if (!computeEnergy && validSavedForces.find(forceGroups) != validSavedForces.end()) {
                // We can just restore the forces we saved earlier, and record which
                // force groups the buffer now holds.

                savedForces[forceGroups].copyTo(cc.getLongForceBuffer());
                context.getLastForceGroups() = forceGroups;
            }
            else {
                // Push parameter values changed by earlier steps to the Context before evaluating.
                recordChangedParameters(context);
                {
                    // NOTE(review): the compute context is deselected for the duration of this call,
                    // presumably so the Context can manage selection itself - confirm.
                    ContextDeselector deselector(cc);
                    energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroups);
                }

                // NOTE(review): the result is cached even when computeEnergy is false. This is
                // presumably safe because computeBothForceAndEnergy requests the energy whenever a
                // later step will read it before the forces are invalidated - confirm.
                savedEnergy[forceGroups] = energy;
                if (needsEnergyParamDerivs) {
                    // Refresh the energy parameter derivatives. Upload the ones used by per-DOF
                    // expressions, in the order given by perDofEnergyParamDerivNames.
                    context.getEnergyParameterDerivatives(energyParamDerivs);
                    if (perDofEnergyParamDerivNames.size() > 0) {
                        for (int i = 0; i < perDofEnergyParamDerivNames.size(); i++)
                            localPerDofEnergyParamDerivs[i] = energyParamDerivs[perDofEnergyParamDerivNames[i]];
                        perDofEnergyParamDerivs.upload(localPerDofEnergyParamDerivs, true);
                    }
                }
            }
            forcesAreValid = true;
        }

        // If this step needs the energy, it is in savedEnergy by now: either it was already
        // cached, or it was just computed above.
        if (needsEnergy[step])
            energy = savedEnergy[forceGroups];
        if (needsGlobals[step] && !deviceGlobalsAreCurrent) {
            // Upload the global values to the device.

            globalValues.upload(localGlobalValues, true);
            deviceGlobalsAreCurrent = true;
        }

        // Execute the step itself. For UpdateContextState, the step's return value replaces
        // the precomputed invalidatesForces flag.
        bool stepInvalidatesForces = invalidatesForces[step];
        if (stepType[step] == CustomIntegrator::ComputePerDof && !merged[step]) {
            // Kernel argument slots 8-11, in order:
            // - 8: the random-number array;
            // - 9: the value returned by prepareRandomNumbers();
            // - 10: the uniform random values;
            // - 11: the energy (as float unless double or mixed precision is in use).
            // A ComputePerDof step with merged[step] set matches no branch here and does nothing.
            kernels[step][0]->setArg(9, integration.prepareRandomNumbers(requiredGaussian[step]));
            kernels[step][0]->setArg(8, integration.getRandom());
            kernels[step][0]->setArg(10, uniformRandoms);
            if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision())
                kernels[step][0]->setArg(11, energy);
            else
                kernels[step][0]->setArg(11, (float) energy);

            // randomKernel is run only when this step requires uniform random numbers.
            if (requiredUniform[step] > 0)
                randomKernel->execute(numAtoms, 64);
            kernels[step][0]->execute(numAtoms, 128);
        }
        else if (stepType[step] == CustomIntegrator::ComputeGlobal) {
            // Evaluate the global expression on the host with fresh random values and the current energy.
            expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
            expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
            expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
            recordGlobalValue(globalExpressions[step][0].evaluate(), stepTarget[step], integrator);
        }
        else if (stepType[step] == CustomIntegrator::ComputeSum) {
            // Same argument setup as a per-DOF step.
            kernels[step][0]->setArg(9, integration.prepareRandomNumbers(requiredGaussian[step]));
            kernels[step][0]->setArg(8, integration.getRandom());
            kernels[step][0]->setArg(10, uniformRandoms);
            if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision())
                kernels[step][0]->setArg(11, energy);
            else
                kernels[step][0]->setArg(11, (float) energy);
            if (requiredUniform[step] > 0)
                randomKernel->execute(numAtoms, 64);

            // sumBuffer is cleared before the per-atom kernel runs. kernels[step][1] then
            // presumably reduces it into summedValue. That value is read back (as double with
            // double/mixed precision, float otherwise) and stored as this step's target.
            cc.clearBuffer(sumBuffer);
            kernels[step][0]->execute(numAtoms, 128);
            kernels[step][1]->execute(sumWorkGroupSize, sumWorkGroupSize);
            if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
                double value;
                summedValue.download(&value);
                recordGlobalValue(value, stepTarget[step], integrator);
            }
            else {
                float value;
                summedValue.download(&value);
                recordGlobalValue(value, stepTarget[step], integrator);
            }
        }
        else if (stepType[step] == CustomIntegrator::UpdateContextState) {
            // Push parameter changes to the Context first. The return value of
            // updateContextState() decides whether this step invalidates the forces.
            recordChangedParameters(context);
            stepInvalidatesForces = context.updateContextState();
        }
        else if (stepType[step] == CustomIntegrator::ConstrainPositions) {
            // Constraints, and this step's kernel, run only when the system has constraints.
            // Virtual sites are recomputed in either case.
            if (hasAnyConstraints) {
                cc.getIntegrationUtilities().applyConstraints(integrator.getConstraintTolerance());
                kernels[step][0]->execute(numAtoms);
            }
            cc.getIntegrationUtilities().computeVirtualSites();
        }
        else if (stepType[step] == CustomIntegrator::ConstrainVelocities) {
            cc.getIntegrationUtilities().applyVelocityConstraints(integrator.getConstraintTolerance());
        }
        else if (stepType[step] == CustomIntegrator::IfBlockStart) {
            // Condition false: skip to the step after the end of the block.
            if (!evaluateCondition(step))
                nextStep = blockEnd[step]+1;
        }
        else if (stepType[step] == CustomIntegrator::WhileBlockStart) {
            // Condition false: leave the loop by skipping past the end of the block.
            if (!evaluateCondition(step))
                nextStep = blockEnd[step]+1;
        }
        else if (stepType[step] == CustomIntegrator::BlockEnd) {
            if (blockEnd[step] != -1)
                nextStep = blockEnd[step]; // Return to the start of a while block.
        }
        if (stepInvalidatesForces) {
            forcesAreValid = false;
            savedEnergy.clear();
        }
        step = nextStep;
    }

    // Push any parameter changes made by the final steps to the Context.
    recordChangedParameters(context);

    // Update the time and step count. This uses the integrator's current step size.
    // recordGlobalValue() calls integrator.setStepSize() when dt is assigned.
    // NOTE(review): if dt is assigned during this step, the time therefore advances by the new
    // value rather than by the one in effect at the start of the step - confirm this is intended.

    cc.setTime(cc.getTime()+integrator.getStepSize());
    cc.setStepCount(cc.getStepCount()+1);
    cc.reorderAtoms();

    // After atoms are reordered, the current forces and all cached force buffers are treated
    // as unusable (presumably because they are laid out in the old atom order).
    if (cc.getAtomsWereReordered()) {
        forcesAreValid = false;
        validSavedForces.clear();
    }

    // Reduce UI lag.

    flushPeriodically(cc);
}

/**
 * Evaluate the condition of an IfBlockStart or WhileBlockStart step.
 *
 * The expression set receives fresh uniform and gaussian random values and the current
 * energy for this step. Then the left-hand (globalExpressions[step][0]) and right-hand
 * (globalExpressions[step][1]) sides are evaluated and compared using comparisons[step].
 *
 * @param step  index of the computation step holding the condition
 * @return true if the comparison holds
 * @throws OpenMMException if comparisons[step] is not one of the handled operators
 *         (the switch has no default case, so an unknown value falls through to the throw)
 */
bool CommonIntegrateCustomStepKernel::evaluateCondition(int step) {
    expressionSet.setVariable(uniformVariableIndex, SimTKOpenMMUtilities::getUniformlyDistributedRandomNumber());
    expressionSet.setVariable(gaussianVariableIndex, SimTKOpenMMUtilities::getNormallyDistributedRandomNumber());
    expressionSet.setVariable(stepEnergyVariableIndex[step], energy);
    double lhs = globalExpressions[step][0].evaluate();
    double rhs = globalExpressions[step][1].evaluate();
    switch (comparisons[step]) {
        case CustomIntegratorUtilities::EQUAL:
            return (lhs == rhs);
        case CustomIntegratorUtilities::LESS_THAN:
            return (lhs < rhs);
        case CustomIntegratorUtilities::GREATER_THAN:
            return (lhs > rhs);
        case CustomIntegratorUtilities::NOT_EQUAL:
            return (lhs != rhs);
        case CustomIntegratorUtilities::LESS_THAN_OR_EQUAL:
            return (lhs <= rhs);
        case CustomIntegratorUtilities::GREATER_THAN_OR_EQUAL:
            return (lhs >= rhs);
    }
    throw OpenMMException("Invalid comparison operator");
}

/**
 * Compute the kinetic energy.
 *
 * Runs kineticEnergyKernel over all atoms, then sumKineticEnergyKernel, and reads the
 * result from summedValue. The result is read as double when double or mixed precision
 * is in use, and as float otherwise.
 *
 * @param context         the Context being integrated
 * @param integrator      the CustomIntegrator
 * @param forcesAreValid  passed through to prepareForComputation()
 * @return the kinetic energy
 */
double CommonIntegrateCustomStepKernel::computeKineticEnergy(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
    ContextSelector selector(cc);
    prepareForComputation(context, integrator, forcesAreValid);
    cc.clearBuffer(sumBuffer);

    // Fill argument slot 8 with the random-number array and set slot 9 to 0.
    kineticEnergyKernel->setArg(8, cc.getIntegrationUtilities().getRandom());
    kineticEnergyKernel->setArg(9, 0);
    kineticEnergyKernel->execute(cc.getNumAtoms());
    sumKineticEnergyKernel->execute(sumWorkGroupSize, sumWorkGroupSize);
    if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
        double ke;
        summedValue.download(&ke);
        return ke;
    }
    else {
        float ke;
        summedValue.download(&ke);
        return ke;
    }
}

/**
 * Store a new value for a global quantity produced by a computation step.
 *
 * Behaviour depends on the target type:
 * - DT: the value is also passed to IntegrationUtilities::setNextStepSize() and to
 *   integrator.setStepSize(). The device copy of the globals is marked stale only if the
 *   value actually changed.
 * - VARIABLE and PARAMETER: the device copy is always marked stale. Parameter values are
 *   pushed to the Context later, by recordChangedParameters().
 *
 * @param value       the new value
 * @param target      which global quantity to update
 * @param integrator  the CustomIntegrator (its step size is updated for DT)
 */
void CommonIntegrateCustomStepKernel::recordGlobalValue(double value, GlobalTarget target, CustomIntegrator& integrator) {
    switch (target.type) {
        case DT:
            if (value != localGlobalValues[dtVariableIndex])
                deviceGlobalsAreCurrent = false;
            expressionSet.setVariable(dtVariableIndex, value);
            localGlobalValues[dtVariableIndex] = value;
            cc.getIntegrationUtilities().setNextStepSize(value);
            integrator.setStepSize(value);
            break;
        case VARIABLE:
        case PARAMETER:
            expressionSet.setVariable(target.variableIndex, value);
            localGlobalValues[target.variableIndex] = value;
            deviceGlobalsAreCurrent = false;
            break;
    }
}

/**
 * Copy context parameters that the integrator has changed back into the Context.
 *
 * Does nothing unless modifiesParameters is set. Otherwise, for each parameter whose
 * local value differs from the Context's current value, context.setParameter() is called
 * with the local value.
 *
 * @param context  the Context whose parameters are updated
 */
void CommonIntegrateCustomStepKernel::recordChangedParameters(ContextImpl& context) {
    if (!modifiesParameters)
        return;
    for (int i = 0; i < (int) parameterNames.size(); i++) {
        double value = context.getParameter(parameterNames[i]);
        if (value != localGlobalValues[parameterVariableIndex[i]])
            context.setParameter(parameterNames[i], localGlobalValues[parameterVariableIndex[i]]);
    }
}

/**
 * Get the current values of the integrator's global variables.
 *
 * On return, values[i] holds global variable i, read from slot globalVariableIndex[i]
 * of localGlobalValues.
 *
 * @param context  unused here
 * @param values   on return, the global variable values
 */
void CommonIntegrateCustomStepKernel::getGlobalVariables(ContextImpl& context, vector& values) const {
    if (!globalValues.isInitialized()) {
        // The data structures haven't been created yet, so just return the list of values that was given earlier
        // (stored in initialGlobalVariables).

        values = initialGlobalVariables;
        return;
    }
    values.resize(numGlobalVariables);
    for (int i = 0; i < numGlobalVariables; i++)
        values[i] = localGlobalValues[globalVariableIndex[i]];
}

/**
 * Set the values of the integrator's global variables.
 *
 * values[i] is written to slot globalVariableIndex[i], both in localGlobalValues and in
 * the expression set. The device copy is then marked stale; execute() re-uploads it when
 * a step next needs the globals. Does nothing if there are no global variables.
 *
 * NOTE(review): assumes values.size() >= numGlobalVariables; no bounds check is done here.
 *
 * @param context  unused here
 * @param values   the new global variable values
 */
void CommonIntegrateCustomStepKernel::setGlobalVariables(ContextImpl& context, const vector& values) {
    if (numGlobalVariables == 0)
        return;
    if (!globalValues.isInitialized()) {
        // The data structures haven't been created yet, so just store the list of values.

        initialGlobalVariables = values;
        return;
    }
    for (int i = 0; i < numGlobalVariables; i++) {
        localGlobalValues[globalVariableIndex[i]] = values[i];
        expressionSet.setVariable(globalVariableIndex[i], values[i]);
    }
    deviceGlobalsAreCurrent = false;
}

/**
 * Get the values of a per-DOF variable.
 *
 * The host copy is downloaded from the device only if it is out of date
 * (localValuesAreCurrent[variable] is false). Host entry i is written to
 * values[cc.getAtomIndex()[i]]; only its x, y and z components are read.
 *
 * @param context   unused here
 * @param variable  index of the per-DOF variable
 * @param values    on return, one entry per element of perDofValues[variable]
 */
void CommonIntegrateCustomStepKernel::getPerDofVariable(ContextImpl& context, int variable, vector& values) const {
    ContextSelector selector(cc);
    values.resize(perDofValues[variable].getSize());
    const vector& order = cc.getAtomIndex();
    if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
        if (!localValuesAreCurrent[variable]) {
            perDofValues[variable].download(localPerDofValuesDouble[variable]);
            localValuesAreCurrent[variable] = true;
        }
        for (int i = 0; i < (int) values.size(); i++) {
            values[order[i]][0] = localPerDofValuesDouble[variable][i].x;
            values[order[i]][1] = localPerDofValuesDouble[variable][i].y;
            values[order[i]][2] = localPerDofValuesDouble[variable][i].z;
        }
    }
    else {
        if (!localValuesAreCurrent[variable]) {
            perDofValues[variable].download(localPerDofValuesFloat[variable]);
            localValuesAreCurrent[variable] = true;
        }
        for (int i = 0; i < (int) values.size(); i++) {
            values[order[i]][0] = localPerDofValuesFloat[variable][i].x;
            values[order[i]][1] = localPerDofValuesFloat[variable][i].y;
            values[order[i]][2] = localPerDofValuesFloat[variable][i].z;
        }
    }
}

/**
 * Set the values of a per-DOF variable.
 *
 * Host slot i receives values[cc.getAtomIndex()[i]], with the fourth component set to 0.
 * Only the host copy is updated here. deviceValuesAreCurrent[variable] is cleared so the
 * device array is uploaded later rather than immediately.
 *
 * @param context   unused here
 * @param variable  index of the per-DOF variable
 * @param values    the new values, indexed by atom
 */
void CommonIntegrateCustomStepKernel::setPerDofVariable(ContextImpl& context, int variable, const vector& values) {
    const vector& order = cc.getAtomIndex();
    localValuesAreCurrent[variable] = true;
    deviceValuesAreCurrent[variable] = false;
    if (cc.getUseDoublePrecision() || cc.getUseMixedPrecision()) {
        localPerDofValuesDouble[variable].resize(values.size());
        for (int i = 0; i < (int) values.size(); i++)
            localPerDofValuesDouble[variable][i] = mm_double4(values[order[i]][0], values[order[i]][1], values[order[i]][2], 0);
    }
    else {
        localPerDofValuesFloat[variable].resize(values.size());
        for (int i = 0; i < (int) values.size(); i++)
            localPerDofValuesFloat[variable][i] = mm_float4(values[order[i]][0], values[order[i]][1], values[order[i]][2], 0);
    }
}