Unverified commit 1c528ca8, authored by Evan Pretti, committed by GitHub
Browse files

Take line search energy difference on CPU before reducing precision (#5242)

parent 26df7a87
......@@ -98,7 +98,7 @@ private:
int maxIterations;
MinimizationReporter* reporter;
double kRestraint, energy;
double kRestraint, energy, energyStart;
bool largeGrad;
ComputeArray constraintIndices, constraintDistances;
......
......@@ -173,8 +173,8 @@ void CommonMinimizeKernel::setup(ContextImpl& context) {
returnFlag.initialize<int>(cc, 1, "returnFlag");
returnValue.initialize(cc, 1, elementSize, "returnValue");
gradNorm.initialize(cc, 1, elementSize, "gradNorm");
lineSearchData.initialize(cc, 4, elementSize, "lineSearchData");
lineSearchDataBackup.initialize(cc, 4, elementSize, "lineSearchDataBackup");
lineSearchData.initialize(cc, 3, elementSize, "lineSearchData");
lineSearchDataBackup.initialize(cc, 3, elementSize, "lineSearchDataBackup");
// Compile kernels and set arguments.
......@@ -341,7 +341,6 @@ void CommonMinimizeKernel::setup(ContextImpl& context) {
lineSearchSetupKernel->addArg(gradNorm);
lineSearchSetupKernel->addArg(lineSearchData);
lineSearchSetupKernel->addArg(numVariables);
lineSearchSetupKernel->addArg(); // energyStart
lineSearchStepKernel = program->createKernel("lineSearchStep");
lineSearchStepKernel->addArg(x);
......@@ -405,12 +404,7 @@ void CommonMinimizeKernel::lbfgs(ContextImpl& context) {
for (int iteration = 1, end = 0;;) {
// Prepare for a line search.
if (mixedIsDouble) {
lineSearchSetupKernel->setArg(9, energy);
}
else {
lineSearchSetupKernel->setArg(9, (float) energy);
}
energyStart = energy;
lineSearchSetupKernel->execute(numVariables);
// Take line search steps.
......@@ -746,11 +740,10 @@ double CommonMinimizeKernel::downloadGradNormSync() {
void CommonMinimizeKernel::runLineSearchKernels() {
if (mixedIsDouble) {
lineSearchDotKernel->setArg(6, isfinite(energy) ? energy : (double) std::numeric_limits<float>::max());
lineSearchDotKernel->setArg(6, isfinite(energy) ? energy - energyStart : (double) std::numeric_limits<float>::max());
}
else {
float hostEnergy = (float) energy;
lineSearchDotKernel->setArg(6, isfinite(hostEnergy) ? hostEnergy : std::numeric_limits<float>::max());
lineSearchDotKernel->setArg(6, isfinite((float) energy) ? (float) (energy - energyStart) : std::numeric_limits<float>::max());
}
lineSearchDotKernel->execute(numVariables);
lineSearchContinueKernel->execute(1);
......
#define LS_DOT_START 0
#define LS_DOT 1
#define LS_ENERGY 2
#define LS_STEP 3
#define LS_STEP 2
#define LS_FAIL 0
#define LS_SUCCEED 1
......@@ -668,8 +667,7 @@ KERNEL void lineSearchSetup(
GLOBAL int* RESTRICT returnFlag,
GLOBAL mixed* RESTRICT gradNorm,
GLOBAL mixed* RESTRICT lineSearchData,
const int numVariables,
const mixed energyStart
const int numVariables
) {
LOCAL volatile mixed temp[TEMP_SIZE];
......@@ -694,7 +692,6 @@ KERNEL void lineSearchSetup(
if (GLOBAL_ID == 0) {
*returnFlag = LS_CONTINUE;
*gradNorm = 0;
lineSearchData[LS_ENERGY] = energyStart;
}
}
......@@ -758,7 +755,6 @@ KERNEL void lineSearchStep(
lineSearchDataBackup[LS_DOT_START] = lineSearchData[LS_DOT_START];
lineSearchDataBackup[LS_DOT] = lineSearchData[LS_DOT] = 0;
lineSearchDataBackup[LS_ENERGY] = lineSearchData[LS_ENERGY];
lineSearchDataBackup[LS_STEP] = lineSearchData[LS_STEP];
}
}
......@@ -770,7 +766,7 @@ KERNEL void lineSearchDot(
GLOBAL int* RESTRICT returnFlag,
GLOBAL const mixed* RESTRICT returnValue,
const int numVariables,
mixed energy
mixed deltaEnergy
) {
LOCAL volatile mixed temp[TEMP_SIZE];
......@@ -781,13 +777,13 @@ KERNEL void lineSearchDot(
// Any restraint energy in returnValue hasn't been downloaded yet to be
// passed back up in the energy parameter, so add it in here.
energy += *returnValue;
deltaEnergy += *returnValue;
// The energy may be such that we don't need to do a dot product and can
// immediately decide to scale the step, so mark this case with LS_SUCCEED.
// This will be checked in the following kernel.
if (!(FABS_MIXED(energy) < FLT_MAX) || energy > lineSearchData[LS_ENERGY] + lineSearchData[LS_STEP] * LBFGS_FTOL * lineSearchData[LS_DOT_START]) {
if (!(FABS_MIXED(deltaEnergy) < FLT_MAX) || deltaEnergy > lineSearchData[LS_STEP] * LBFGS_FTOL * lineSearchData[LS_DOT_START]) {
if (GLOBAL_ID == 0) {
*returnFlag = LS_SUCCEED;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment