Unverified commit 1c528ca8, authored by Evan Pretti and committed by GitHub
Browse files

Take line search energy difference on CPU before reducing precision (#5242)

parent 26df7a87
...@@ -98,7 +98,7 @@ private: ...@@ -98,7 +98,7 @@ private:
int maxIterations; int maxIterations;
MinimizationReporter* reporter; MinimizationReporter* reporter;
double kRestraint, energy; double kRestraint, energy, energyStart;
bool largeGrad; bool largeGrad;
ComputeArray constraintIndices, constraintDistances; ComputeArray constraintIndices, constraintDistances;
......
...@@ -173,8 +173,8 @@ void CommonMinimizeKernel::setup(ContextImpl& context) { ...@@ -173,8 +173,8 @@ void CommonMinimizeKernel::setup(ContextImpl& context) {
returnFlag.initialize<int>(cc, 1, "returnFlag"); returnFlag.initialize<int>(cc, 1, "returnFlag");
returnValue.initialize(cc, 1, elementSize, "returnValue"); returnValue.initialize(cc, 1, elementSize, "returnValue");
gradNorm.initialize(cc, 1, elementSize, "gradNorm"); gradNorm.initialize(cc, 1, elementSize, "gradNorm");
lineSearchData.initialize(cc, 4, elementSize, "lineSearchData"); lineSearchData.initialize(cc, 3, elementSize, "lineSearchData");
lineSearchDataBackup.initialize(cc, 4, elementSize, "lineSearchDataBackup"); lineSearchDataBackup.initialize(cc, 3, elementSize, "lineSearchDataBackup");
// Compile kernels and set arguments. // Compile kernels and set arguments.
...@@ -341,7 +341,6 @@ void CommonMinimizeKernel::setup(ContextImpl& context) { ...@@ -341,7 +341,6 @@ void CommonMinimizeKernel::setup(ContextImpl& context) {
lineSearchSetupKernel->addArg(gradNorm); lineSearchSetupKernel->addArg(gradNorm);
lineSearchSetupKernel->addArg(lineSearchData); lineSearchSetupKernel->addArg(lineSearchData);
lineSearchSetupKernel->addArg(numVariables); lineSearchSetupKernel->addArg(numVariables);
lineSearchSetupKernel->addArg(); // energyStart
lineSearchStepKernel = program->createKernel("lineSearchStep"); lineSearchStepKernel = program->createKernel("lineSearchStep");
lineSearchStepKernel->addArg(x); lineSearchStepKernel->addArg(x);
...@@ -405,12 +404,7 @@ void CommonMinimizeKernel::lbfgs(ContextImpl& context) { ...@@ -405,12 +404,7 @@ void CommonMinimizeKernel::lbfgs(ContextImpl& context) {
for (int iteration = 1, end = 0;;) { for (int iteration = 1, end = 0;;) {
// Prepare for a line search. // Prepare for a line search.
if (mixedIsDouble) { energyStart = energy;
lineSearchSetupKernel->setArg(9, energy);
}
else {
lineSearchSetupKernel->setArg(9, (float) energy);
}
lineSearchSetupKernel->execute(numVariables); lineSearchSetupKernel->execute(numVariables);
// Take line search steps. // Take line search steps.
...@@ -746,11 +740,10 @@ double CommonMinimizeKernel::downloadGradNormSync() { ...@@ -746,11 +740,10 @@ double CommonMinimizeKernel::downloadGradNormSync() {
void CommonMinimizeKernel::runLineSearchKernels() { void CommonMinimizeKernel::runLineSearchKernels() {
if (mixedIsDouble) { if (mixedIsDouble) {
lineSearchDotKernel->setArg(6, isfinite(energy) ? energy : (double) std::numeric_limits<float>::max()); lineSearchDotKernel->setArg(6, isfinite(energy) ? energy - energyStart : (double) std::numeric_limits<float>::max());
} }
else { else {
float hostEnergy = (float) energy; lineSearchDotKernel->setArg(6, isfinite((float) energy) ? (float) (energy - energyStart) : std::numeric_limits<float>::max());
lineSearchDotKernel->setArg(6, isfinite(hostEnergy) ? hostEnergy : std::numeric_limits<float>::max());
} }
lineSearchDotKernel->execute(numVariables); lineSearchDotKernel->execute(numVariables);
lineSearchContinueKernel->execute(1); lineSearchContinueKernel->execute(1);
......
#define LS_DOT_START 0 #define LS_DOT_START 0
#define LS_DOT 1 #define LS_DOT 1
#define LS_ENERGY 2 #define LS_STEP 2
#define LS_STEP 3
#define LS_FAIL 0 #define LS_FAIL 0
#define LS_SUCCEED 1 #define LS_SUCCEED 1
...@@ -668,8 +667,7 @@ KERNEL void lineSearchSetup( ...@@ -668,8 +667,7 @@ KERNEL void lineSearchSetup(
GLOBAL int* RESTRICT returnFlag, GLOBAL int* RESTRICT returnFlag,
GLOBAL mixed* RESTRICT gradNorm, GLOBAL mixed* RESTRICT gradNorm,
GLOBAL mixed* RESTRICT lineSearchData, GLOBAL mixed* RESTRICT lineSearchData,
const int numVariables, const int numVariables
const mixed energyStart
) { ) {
LOCAL volatile mixed temp[TEMP_SIZE]; LOCAL volatile mixed temp[TEMP_SIZE];
...@@ -694,7 +692,6 @@ KERNEL void lineSearchSetup( ...@@ -694,7 +692,6 @@ KERNEL void lineSearchSetup(
if (GLOBAL_ID == 0) { if (GLOBAL_ID == 0) {
*returnFlag = LS_CONTINUE; *returnFlag = LS_CONTINUE;
*gradNorm = 0; *gradNorm = 0;
lineSearchData[LS_ENERGY] = energyStart;
} }
} }
...@@ -758,7 +755,6 @@ KERNEL void lineSearchStep( ...@@ -758,7 +755,6 @@ KERNEL void lineSearchStep(
lineSearchDataBackup[LS_DOT_START] = lineSearchData[LS_DOT_START]; lineSearchDataBackup[LS_DOT_START] = lineSearchData[LS_DOT_START];
lineSearchDataBackup[LS_DOT] = lineSearchData[LS_DOT] = 0; lineSearchDataBackup[LS_DOT] = lineSearchData[LS_DOT] = 0;
lineSearchDataBackup[LS_ENERGY] = lineSearchData[LS_ENERGY];
lineSearchDataBackup[LS_STEP] = lineSearchData[LS_STEP]; lineSearchDataBackup[LS_STEP] = lineSearchData[LS_STEP];
} }
} }
...@@ -770,7 +766,7 @@ KERNEL void lineSearchDot( ...@@ -770,7 +766,7 @@ KERNEL void lineSearchDot(
GLOBAL int* RESTRICT returnFlag, GLOBAL int* RESTRICT returnFlag,
GLOBAL const mixed* RESTRICT returnValue, GLOBAL const mixed* RESTRICT returnValue,
const int numVariables, const int numVariables,
mixed energy mixed deltaEnergy
) { ) {
LOCAL volatile mixed temp[TEMP_SIZE]; LOCAL volatile mixed temp[TEMP_SIZE];
...@@ -781,13 +777,13 @@ KERNEL void lineSearchDot( ...@@ -781,13 +777,13 @@ KERNEL void lineSearchDot(
// Any restraint energy in returnValue hasn't been downloaded yet to be // Any restraint energy in returnValue hasn't been downloaded yet to be
// passed back up in the energy parameter, so add it in here. // passed back up in the energy parameter, so add it in here.
energy += *returnValue; deltaEnergy += *returnValue;
// The energy may be such that we don't need to do a dot product and can // The energy may be such that we don't need to do a dot product and can
// immediately decide to scale the step, so mark this case with LS_SUCCEED. // immediately decide to scale the step, so mark this case with LS_SUCCEED.
// This will be checked in the following kernel. // This will be checked in the following kernel.
if (!(FABS_MIXED(energy) < FLT_MAX) || energy > lineSearchData[LS_ENERGY] + lineSearchData[LS_STEP] * LBFGS_FTOL * lineSearchData[LS_DOT_START]) { if (!(FABS_MIXED(deltaEnergy) < FLT_MAX) || deltaEnergy > lineSearchData[LS_STEP] * LBFGS_FTOL * lineSearchData[LS_DOT_START]) {
if (GLOBAL_ID == 0) { if (GLOBAL_ID == 0) {
*returnFlag = LS_SUCCEED; *returnFlag = LS_SUCCEED;
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment