Commit dbdf1c68 authored by peastman's avatar peastman
Browse files

OpenCL implementation of parameter derivatives for CustomIntegrator

parent 2a52e208
...@@ -1282,7 +1282,7 @@ public: ...@@ -1282,7 +1282,7 @@ public:
enum GlobalTargetType {DT, VARIABLE, PARAMETER}; enum GlobalTargetType {DT, VARIABLE, PARAMETER};
OpenCLIntegrateCustomStepKernel(std::string name, const Platform& platform, OpenCLContext& cl) : IntegrateCustomStepKernel(name, platform), cl(cl), OpenCLIntegrateCustomStepKernel(std::string name, const Platform& platform, OpenCLContext& cl) : IntegrateCustomStepKernel(name, platform), cl(cl),
hasInitializedKernels(false), localValuesAreCurrent(false), globalValues(NULL), sumBuffer(NULL), summedValue(NULL), uniformRandoms(NULL), hasInitializedKernels(false), localValuesAreCurrent(false), globalValues(NULL), sumBuffer(NULL), summedValue(NULL), uniformRandoms(NULL),
randomSeed(NULL), perDofValues(NULL) { randomSeed(NULL), perDofEnergyParamDerivs(NULL), perDofValues(NULL), needsEnergyParamDerivs(false) {
} }
~OpenCLIntegrateCustomStepKernel(); ~OpenCLIntegrateCustomStepKernel();
/** /**
...@@ -1347,8 +1347,11 @@ public: ...@@ -1347,8 +1347,11 @@ public:
private: private:
class ReorderListener; class ReorderListener;
class GlobalTarget; class GlobalTarget;
class DerivFunction;
std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const std::string& forceName, const std::string& energyName); std::string createPerDofComputation(const std::string& variable, const Lepton::ParsedExpression& expr, int component, CustomIntegrator& integrator, const std::string& forceName, const std::string& energyName);
void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid); void prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid);
Lepton::ExpressionTreeNode replaceDerivFunctions(const Lepton::ExpressionTreeNode& node, OpenMM::ContextImpl& context);
void findExpressionsForDerivs(const Lepton::ExpressionTreeNode& node, std::vector<std::pair<Lepton::ExpressionTreeNode, std::string> >& variableNodes);
void recordGlobalValue(double value, GlobalTarget target); void recordGlobalValue(double value, GlobalTarget target);
void recordChangedParameters(ContextImpl& context); void recordChangedParameters(ContextImpl& context);
bool evaluateCondition(int step); bool evaluateCondition(int step);
...@@ -1356,18 +1359,23 @@ private: ...@@ -1356,18 +1359,23 @@ private:
double energy; double energy;
float energyFloat; float energyFloat;
int numGlobalVariables; int numGlobalVariables;
bool hasInitializedKernels, deviceValuesAreCurrent, deviceGlobalsAreCurrent, modifiesParameters, keNeedsForce, hasAnyConstraints; bool hasInitializedKernels, deviceValuesAreCurrent, deviceGlobalsAreCurrent, modifiesParameters, keNeedsForce, hasAnyConstraints, needsEnergyParamDerivs;
mutable bool localValuesAreCurrent; mutable bool localValuesAreCurrent;
OpenCLArray* globalValues; OpenCLArray* globalValues;
OpenCLArray* sumBuffer; OpenCLArray* sumBuffer;
OpenCLArray* summedValue; OpenCLArray* summedValue;
OpenCLArray* uniformRandoms; OpenCLArray* uniformRandoms;
OpenCLArray* randomSeed; OpenCLArray* randomSeed;
OpenCLArray* perDofEnergyParamDerivs;
std::map<int, OpenCLArray*> savedForces; std::map<int, OpenCLArray*> savedForces;
std::set<int> validSavedForces; std::set<int> validSavedForces;
OpenCLParameterSet* perDofValues; OpenCLParameterSet* perDofValues;
mutable std::vector<std::vector<cl_float> > localPerDofValuesFloat; mutable std::vector<std::vector<cl_float> > localPerDofValuesFloat;
mutable std::vector<std::vector<cl_double> > localPerDofValuesDouble; mutable std::vector<std::vector<cl_double> > localPerDofValuesDouble;
std::map<std::string, double> energyParamDerivs;
std::vector<std::string> perDofEnergyParamDerivNames;
std::vector<cl_float> localPerDofEnergyParamDerivsFloat;
std::vector<cl_double> localPerDofEnergyParamDerivsDouble;
std::vector<float> globalValuesFloat; std::vector<float> globalValuesFloat;
std::vector<double> globalValuesDouble; std::vector<double> globalValuesDouble;
std::vector<double> initialGlobalVariables; std::vector<double> initialGlobalVariables;
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include "OpenCLIntegrationUtilities.h" #include "OpenCLIntegrationUtilities.h"
#include "OpenCLNonbondedUtilities.h" #include "OpenCLNonbondedUtilities.h"
#include "OpenCLKernelSources.h" #include "OpenCLKernelSources.h"
#include "lepton/CustomFunction.h"
#include "lepton/ExpressionTreeNode.h" #include "lepton/ExpressionTreeNode.h"
#include "lepton/Operation.h" #include "lepton/Operation.h"
#include "lepton/Parser.h" #include "lepton/Parser.h"
...@@ -55,8 +56,7 @@ ...@@ -55,8 +56,7 @@
using namespace OpenMM; using namespace OpenMM;
using namespace std; using namespace std;
using Lepton::ExpressionTreeNode; using namespace Lepton;
using Lepton::Operation;
static void setPosqCorrectionArg(OpenCLContext& cl, cl::Kernel& kernel, int index) { static void setPosqCorrectionArg(OpenCLContext& cl, cl::Kernel& kernel, int index) {
if (cl.getUseMixedPrecision()) if (cl.getUseMixedPrecision())
...@@ -6675,6 +6675,27 @@ private: ...@@ -6675,6 +6675,27 @@ private:
vector<int> lastAtomOrder; vector<int> lastAtomOrder;
}; };
class OpenCLIntegrateCustomStepKernel::DerivFunction : public CustomFunction {
public:
DerivFunction(map<string, double>& energyParamDerivs, const string& param) : energyParamDerivs(energyParamDerivs), param(param) {
}
int getNumArguments() const {
return 0;
}
double evaluate(const double* arguments) const {
return energyParamDerivs[param];
}
double evaluateDerivative(const double* arguments, const int* derivOrder) const {
return 0;
}
CustomFunction* clone() const {
return new DerivFunction(energyParamDerivs, param);
}
private:
map<string, double>& energyParamDerivs;
string param;
};
OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() { OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() {
if (globalValues != NULL) if (globalValues != NULL)
delete globalValues; delete globalValues;
...@@ -6686,6 +6707,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() { ...@@ -6686,6 +6707,8 @@ OpenCLIntegrateCustomStepKernel::~OpenCLIntegrateCustomStepKernel() {
delete uniformRandoms; delete uniformRandoms;
if (randomSeed != NULL) if (randomSeed != NULL)
delete randomSeed; delete randomSeed;
if (perDofEnergyParamDerivs != NULL)
delete perDofEnergyParamDerivs;
if (perDofValues != NULL) if (perDofValues != NULL)
delete perDofValues; delete perDofValues;
for (map<int, OpenCLArray*>::iterator iter = savedForces.begin(); iter != savedForces.end(); ++iter) for (map<int, OpenCLArray*>::iterator iter = savedForces.begin(); iter != savedForces.end(); ++iter)
...@@ -6740,7 +6763,11 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va ...@@ -6740,7 +6763,11 @@ string OpenCLIntegrateCustomStepKernel::createPerDofComputation(const string& va
vector<const TabulatedFunction*> functions; vector<const TabulatedFunction*> functions;
vector<pair<string, string> > functionNames; vector<pair<string, string> > functionNames;
string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float"); string tempType = (cl.getSupportsDoublePrecision() ? "double" : "float");
return cl.getExpressionUtilities().createExpressions(expressions, variables, functions, functionNames, "temp"+cl.intToString(component)+"_", tempType); vector<pair<ExpressionTreeNode, string> > variableNodes;
findExpressionsForDerivs(expr.getRootNode(), variableNodes);
for (map<string, string>::const_iterator iter = variables.begin(); iter != variables.end(); ++iter)
variableNodes.push_back(make_pair(ExpressionTreeNode(new Operation::Variable(iter->first)), iter->second));
return cl.getExpressionUtilities().createExpressions(expressions, variableNodes, functions, functionNames, "temp"+cl.intToString(component)+"_", tempType);
} }
void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
...@@ -6782,7 +6809,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -6782,7 +6809,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
blockEnd[blockEnd[step]] = step; // Record where to branch back to. blockEnd[blockEnd[step]] = step; // Record where to branch back to.
if (stepType[step] == CustomIntegrator::ComputeGlobal || stepType[step] == CustomIntegrator::IfBlockStart || stepType[step] == CustomIntegrator::WhileBlockStart) if (stepType[step] == CustomIntegrator::ComputeGlobal || stepType[step] == CustomIntegrator::IfBlockStart || stepType[step] == CustomIntegrator::WhileBlockStart)
for (int i = 0; i < (int) expression[step].size(); i++) for (int i = 0; i < (int) expression[step].size(); i++)
globalExpressions[step].push_back(expression[step][i].createCompiledExpression()); globalExpressions[step].push_back(ParsedExpression(replaceDerivFunctions(expression[step][i].getRootNode(), context)).createCompiledExpression());
} }
for (int step = 0; step < numSteps; step++) { for (int step = 0; step < numSteps; step++) {
for (int i = 0; i < (int) globalExpressions[step].size(); i++) for (int i = 0; i < (int) globalExpressions[step].size(); i++)
...@@ -6845,6 +6872,10 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -6845,6 +6872,10 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
globalValuesDouble[parameterVariableIndex[i]] = value; globalValuesDouble[parameterVariableIndex[i]] = value;
expressionSet.setVariable(parameterVariableIndex[i], value); expressionSet.setVariable(parameterVariableIndex[i], value);
} }
int numContextParams = context.getParameters().size();
localPerDofEnergyParamDerivsFloat.resize(numContextParams);
localPerDofEnergyParamDerivsDouble.resize(numContextParams);
perDofEnergyParamDerivs = new OpenCLArray(cl, max(1, numContextParams), elementSize, "perDofEnergyParamDerivs");
// Record information about the targets of steps that will be stored in global variables. // Record information about the targets of steps that will be stored in global variables.
...@@ -6993,6 +7024,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -6993,6 +7024,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
kernel.setArg<cl::Buffer>(index++, globalValues->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, globalValues->getDeviceBuffer());
kernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer()); kernel.setArg<cl::Buffer>(index++, sumBuffer->getDeviceBuffer());
index += 4; index += 4;
kernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer());
for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
kernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory()); kernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory());
if (stepType[step] == CustomIntegrator::ComputeSum) { if (stepType[step] == CustomIntegrator::ComputeSum) {
...@@ -7086,6 +7118,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7086,6 +7118,7 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
kineticEnergyKernel.setArg<cl_double>(index++, 0.0); kineticEnergyKernel.setArg<cl_double>(index++, 0.0);
else else
kineticEnergyKernel.setArg<cl_float>(index++, 0.0f); kineticEnergyKernel.setArg<cl_float>(index++, 0.0f);
kineticEnergyKernel.setArg<cl::Buffer>(index++, perDofEnergyParamDerivs->getDeviceBuffer());
for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++) for (int i = 0; i < (int) perDofValues->getBuffers().size(); i++)
kineticEnergyKernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory()); kineticEnergyKernel.setArg<cl::Memory>(index++, perDofValues->getBuffers()[i].getMemory());
keNeedsForce = usesVariable(keExpression, "f"); keNeedsForce = usesVariable(keExpression, "f");
...@@ -7121,6 +7154,47 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context ...@@ -7121,6 +7154,47 @@ void OpenCLIntegrateCustomStepKernel::prepareForComputation(ContextImpl& context
} }
} }
ExpressionTreeNode OpenCLIntegrateCustomStepKernel::replaceDerivFunctions(const ExpressionTreeNode& node, ContextImpl& context) {
// This is called recursively to identify calls to the deriv() function inside global expressions,
// and replace them with a custom function that returns the correct value.
const Operation& op = node.getOperation();
if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
string param = node.getChildren()[1].getOperation().getName();
if (context.getParameters().find(param) == context.getParameters().end())
throw OpenMMException("The second argument to deriv() must be a context parameter");
needsEnergyParamDerivs = true;
return ExpressionTreeNode(new Operation::Custom("deriv", new DerivFunction(energyParamDerivs, param)));
}
else {
vector<ExpressionTreeNode> children;
for (int i = 0; i < (int) node.getChildren().size(); i++)
children.push_back(replaceDerivFunctions(node.getChildren()[i], context));
return ExpressionTreeNode(op.clone(), children);
}
}
void OpenCLIntegrateCustomStepKernel::findExpressionsForDerivs(const ExpressionTreeNode& node, vector<pair<ExpressionTreeNode, string> >& variableNodes) {
// This is called recursively to identify calls to the deriv() function inside per-DOF expressions,
// and record the code to replace them with.
const Operation& op = node.getOperation();
if (op.getId() == Operation::CUSTOM && op.getName() == "deriv") {
string param = node.getChildren()[1].getOperation().getName();
int index;
for (index = 0; index < perDofEnergyParamDerivNames.size() && param != perDofEnergyParamDerivNames[index]; index++)
;
if (index == perDofEnergyParamDerivNames.size())
perDofEnergyParamDerivNames.push_back(param);
variableNodes.push_back(make_pair(node, "energyParamDerivs["+cl.intToString(index)+"]"));
needsEnergyParamDerivs = true;
}
else {
for (int i = 0; i < (int) node.getChildren().size(); i++)
findExpressionsForDerivs(node.getChildren()[i], variableNodes);
}
}
void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) { void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegrator& integrator, bool& forcesAreValid) {
prepareForComputation(context, integrator, forcesAreValid); prepareForComputation(context, integrator, forcesAreValid);
OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities(); OpenCLIntegrationUtilities& integration = cl.getIntegrationUtilities();
...@@ -7156,6 +7230,21 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr ...@@ -7156,6 +7230,21 @@ void OpenCLIntegrateCustomStepKernel::execute(ContextImpl& context, CustomIntegr
else { else {
recordChangedParameters(context); recordChangedParameters(context);
energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]); energy = context.calcForcesAndEnergy(computeForce, computeEnergy, forceGroupFlags[step]);
if (needsEnergyParamDerivs) {
context.getEnergyParameterDerivatives(energyParamDerivs);
if (perDofEnergyParamDerivNames.size() > 0) {
if (cl.getUseDoublePrecision() || cl.getUseMixedPrecision()) {
for (int i = 0; i < perDofEnergyParamDerivNames.size(); i++)
localPerDofEnergyParamDerivsDouble[i] = energyParamDerivs[perDofEnergyParamDerivNames[i]];
perDofEnergyParamDerivs->upload(localPerDofEnergyParamDerivsDouble);
}
else {
for (int i = 0; i < perDofEnergyParamDerivNames.size(); i++)
localPerDofEnergyParamDerivsFloat[i] = (float) energyParamDerivs[perDofEnergyParamDerivNames[i]];
perDofEnergyParamDerivs->upload(localPerDofEnergyParamDerivsFloat);
}
}
}
forcesAreValid = true; forcesAreValid = true;
} }
} }
......
...@@ -25,7 +25,8 @@ void storePos(__global real4* restrict posq, __global real4* restrict posqCorrec ...@@ -25,7 +25,8 @@ void storePos(__global real4* restrict posq, __global real4* restrict posqCorrec
__kernel void computePerDof(__global real4* restrict posq, __global real4* restrict posqCorrection, __global mixed4* restrict posDelta, __kernel void computePerDof(__global real4* restrict posq, __global real4* restrict posqCorrection, __global mixed4* restrict posDelta,
__global mixed4* restrict velm, __global const real4* restrict force, __global const mixed2* restrict dt, __global const mixed* restrict globals, __global mixed4* restrict velm, __global const real4* restrict force, __global const mixed2* restrict dt, __global const mixed* restrict globals,
__global mixed* restrict sum, __global const float4* restrict gaussianValues, unsigned int gaussianBaseIndex, __global const float4* restrict uniformValues, const real energy __global mixed* restrict sum, __global const float4* restrict gaussianValues, unsigned int gaussianBaseIndex, __global const float4* restrict uniformValues,
const real energy, __global mixed* restrict energyParamDerivs
PARAMETER_ARGUMENTS) { PARAMETER_ARGUMENTS) {
mixed stepSize = dt[0].y; mixed stepSize = dt[0].y;
int index = get_global_id(0); int index = get_global_id(0);
......
...@@ -810,20 +810,20 @@ void testEnergyParameterDerivatives() { ...@@ -810,20 +810,20 @@ void testEnergyParameterDerivatives() {
CustomIntegrator integrator(0.1); CustomIntegrator integrator(0.1);
integrator.addGlobalVariable("dEdK", 0.0); integrator.addGlobalVariable("dEdK", 0.0);
integrator.addGlobalVariable("dEdr0", 0.0); integrator.addGlobalVariable("dEdr0", 0.0);
integrator.addGlobalVariable("dEdtheta0", 0.0); integrator.addPerDofVariable("dEdtheta0", 0.0);
integrator.addGlobalVariable("dEdK_0", 0.0); integrator.addGlobalVariable("dEdK_0", 0.0);
integrator.addGlobalVariable("dEdr0_0", 0.0); integrator.addPerDofVariable("dEdr0_0", 0.0);
integrator.addGlobalVariable("dEdtheta0_0", 0.0); integrator.addGlobalVariable("dEdtheta0_0", 0.0);
integrator.addGlobalVariable("dEdK_1", 0.0); integrator.addPerDofVariable("dEdK_1", 0.0);
integrator.addGlobalVariable("dEdr0_1", 0.0); integrator.addGlobalVariable("dEdr0_1", 0.0);
integrator.addGlobalVariable("dEdtheta0_1", 0.0); integrator.addGlobalVariable("dEdtheta0_1", 0.0);
integrator.addComputeGlobal("dEdK", "deriv(energy, K)"); integrator.addComputeGlobal("dEdK", "deriv(energy, K)");
integrator.addComputeGlobal("dEdr0", "deriv(energy, r0)"); integrator.addComputeGlobal("dEdr0", "deriv(energy, r0)");
integrator.addComputeGlobal("dEdtheta0", "deriv(energy, theta0)"); integrator.addComputePerDof("dEdtheta0", "deriv(energy, theta0)");
integrator.addComputeGlobal("dEdK_0", "deriv(energy0, K)"); integrator.addComputeGlobal("dEdK_0", "deriv(energy0, K)");
integrator.addComputeGlobal("dEdr0_0", "deriv(energy0, r0)"); integrator.addComputePerDof("dEdr0_0", "deriv(energy0, r0)");
integrator.addComputeGlobal("dEdtheta0_0", "deriv(energy0, theta0)"); integrator.addComputeGlobal("dEdtheta0_0", "deriv(energy0, theta0)");
integrator.addComputeGlobal("dEdK_1", "deriv(energy1, K)"); integrator.addComputePerDof("dEdK_1", "deriv(energy1, K)");
integrator.addComputeGlobal("dEdr0_1", "deriv(energy1, r0)"); integrator.addComputeGlobal("dEdr0_1", "deriv(energy1, r0)");
integrator.addComputeGlobal("dEdtheta0_1", "deriv(energy1, theta0)"); integrator.addComputeGlobal("dEdtheta0_1", "deriv(energy1, theta0)");
...@@ -839,19 +839,23 @@ void testEnergyParameterDerivatives() { ...@@ -839,19 +839,23 @@ void testEnergyParameterDerivatives() {
// Check the results. // Check the results.
integrator.step(1); integrator.step(1);
vector<Vec3> values;
double dEdK_0 = (1.0-1.5)*(1.0-1.5); double dEdK_0 = (1.0-1.5)*(1.0-1.5);
double dEdK_1 = (M_PI/2-M_PI/3)*(M_PI/2-M_PI/3); double dEdK_1 = (M_PI/2-M_PI/3)*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(dEdK_0, integrator.getGlobalVariableByName("dEdK_0"), 1e-5); ASSERT_EQUAL_TOL(dEdK_0, integrator.getGlobalVariableByName("dEdK_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdK_1, integrator.getGlobalVariableByName("dEdK_1"), 1e-5); integrator.getPerDofVariableByName("dEdK_1", values);
ASSERT_EQUAL_TOL(dEdK_1, values[0][2], 1e-5);
ASSERT_EQUAL_TOL(dEdK_0+dEdK_1, integrator.getGlobalVariableByName("dEdK"), 1e-5); ASSERT_EQUAL_TOL(dEdK_0+dEdK_1, integrator.getGlobalVariableByName("dEdK"), 1e-5);
double dEdr0 = -2.0*2.0*(1.0-1.5); double dEdr0 = -2.0*2.0*(1.0-1.5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0_0"), 1e-5); integrator.getPerDofVariableByName("dEdr0_0", values);
ASSERT_EQUAL_TOL(dEdr0, values[1][0], 1e-5);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdr0_1"), 1e-5); ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdr0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0"), 1e-5); ASSERT_EQUAL_TOL(dEdr0, integrator.getGlobalVariableByName("dEdr0"), 1e-5);
double dEdtheta0 = -2.0*2.0*(M_PI/2-M_PI/3); double dEdtheta0 = -2.0*2.0*(M_PI/2-M_PI/3);
ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdtheta0_0"), 1e-5); ASSERT_EQUAL_TOL(0.0, integrator.getGlobalVariableByName("dEdtheta0_0"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0_1"), 1e-5); ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0_1"), 1e-5);
ASSERT_EQUAL_TOL(dEdtheta0, integrator.getGlobalVariableByName("dEdtheta0"), 1e-5); integrator.getPerDofVariableByName("dEdtheta0", values);
ASSERT_EQUAL_TOL(dEdtheta0, values[2][1], 1e-5);
} }
void runPlatformTests(); void runPlatformTests();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment