constantPotentialCGSolver.cc 24.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
#define WARP_SIZE 32

// Select a warp-level "shuffle down" primitive when the target provides one:
// CUDA devices of compute capability 7.0+ use the synchronizing variant with a
// full lane mask, HIP uses its unmasked form.  On every other target the macro
// stays undefined and the reductions fall back to local-memory trees.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
    #define WARP_SHUFFLE_DOWN(local, offset) __shfl_down_sync(0xffffffff, local, offset)
#elif defined(USE_HIP)
    #define WARP_SHUFFLE_DOWN(local, offset) __shfl_down(local, offset)
#endif

// Number of scratch slots the reductions need: one per thread of the block for
// the local-memory fallback, or up to one per warp when shuffles are available.
#ifndef WARP_SHUFFLE_DOWN
    #define TEMP_SIZE THREAD_BLOCK_SIZE
#else
    #define TEMP_SIZE WARP_SIZE
#endif

// Partial sums accumulated by solveLoopStep1 (one set per thread block) and
// combined by solveLoopStep2.
typedef struct {
    real gradStepSq;     // sum of gradStep[i] * gradStep[i]
    real qStepGradStep;  // sum of qStep[i] * gradStep[i]
    real qStepGrad;      // sum of qStep[i] * grad[i]
    real q;              // sum of q[i]
    real qStep;          // sum of qStep[i]
} BlockSums1;

// Sums produced by solveLoopStep4 and consumed by solveLoopStep5.
typedef struct {
    real projGradSq;     // sum of projGrad[i] * projGrad[i]
    real precGradStep;   // sum of precGrad[i] * gradStep[i]
    real precGrad;       // sum of precGrad[i]
} BlockSums2;

// Sum value from each thread (using temp).  Use real type variables (float on
// single and mixed precision modes, double on double precision mode).
//
// Every thread of the work group must call this function (it executes
// SYNC_THREADS barriers), and every thread gets the block-wide total back.
// temp is scratch space; the callers in this file declare it with TEMP_SIZE
// elements.
DEVICE real reduceReal(real value, LOCAL_ARG volatile real* temp) {
    const int thread = LOCAL_ID;
    // Presumably keeps temp from being overwritten while other threads are
    // still reading temp[0] from an earlier reduction.
    SYNC_THREADS;
#ifdef WARP_SHUFFLE_DOWN
    // NOTE(review): the division truncates, so a trailing partial warp would
    // not be counted; this assumes LOCAL_SIZE is a multiple of WARP_SIZE.
    const int warpCount = LOCAL_SIZE / WARP_SIZE;
    const int warp = thread / WARP_SIZE;
    const int lane = thread % WARP_SIZE;
    // Stage 1: tree-reduce within each warp with shuffles.  Lane 0 of each
    // warp ends up holding that warp's sum and stores it in temp[warp].
    for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
        value += WARP_SHUFFLE_DOWN(value, step);
    }
    if (!lane) {
        temp[warp] = value;
    }
    SYNC_THREADS;
    // Stage 2: the first warp reduces the per-warp sums (lanes at or beyond
    // warpCount contribute zero), and lane 0 publishes the total in temp[0].
    if (!warp) {
        value = lane < warpCount ? temp[lane] : 0;
        for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
            value += WARP_SHUFFLE_DOWN(value, step);
        }
        if (!lane) {
            temp[0] = value;
        }
    }
    SYNC_THREADS;
#else
    // Fallback: pairwise tree reduction in local memory, one slot per thread.
    // Strides below WARP_SIZE / 2 only combine slots within the same aligned
    // group of WARP_SIZE threads, so they synchronize with SYNC_WARPS
    // (presumably sufficient between threads of one warp).  Larger strides
    // need full SYNC_THREADS barriers.
    temp[thread] = value;
    SYNC_THREADS;
    for (int step = 1; step < WARP_SIZE / 2; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread] += temp[thread + step];
        }
        SYNC_WARPS;
    }
    for (int step = WARP_SIZE / 2; step < LOCAL_SIZE; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread] += temp[thread + step];
        }
        SYNC_THREADS;
    }
#endif
    return temp[0];
}

// Performs the equivalent of reduceReal() on 5 values simultaneously.
// Each field of BlockSums1 is summed independently across the work group.
// Every thread must call this (it executes SYNC_THREADS barriers), and every
// thread gets the combined sums back.  Unlike reduceReal(), temp is not
// declared volatile.
DEVICE BlockSums1 reduceBlockSums1(BlockSums1 value, LOCAL_ARG BlockSums1* temp) {
    const int thread = LOCAL_ID;
    // Presumably keeps temp from being overwritten while a previous
    // reduction's result (temp[0]) is still being read.
    SYNC_THREADS;
#ifdef WARP_SHUFFLE_DOWN
    // NOTE(review): the division truncates, so this assumes LOCAL_SIZE is a
    // multiple of WARP_SIZE; a trailing partial warp would be dropped.
    const int warpCount = LOCAL_SIZE / WARP_SIZE;
    const int warp = thread / WARP_SIZE;
    const int lane = thread % WARP_SIZE;
    // Stage 1: shuffle-reduce every field within each warp.  Lane 0 stores the
    // warp's partial sums in temp[warp].
    for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
        value.gradStepSq += WARP_SHUFFLE_DOWN(value.gradStepSq, step);
        value.qStepGradStep += WARP_SHUFFLE_DOWN(value.qStepGradStep, step);
        value.qStepGrad += WARP_SHUFFLE_DOWN(value.qStepGrad, step);
        value.q += WARP_SHUFFLE_DOWN(value.q, step);
        value.qStep += WARP_SHUFFLE_DOWN(value.qStep, step);
    }
    if (!lane) {
        temp[warp] = value;
    }
    SYNC_THREADS;
    // Stage 2: the first warp combines the per-warp partial sums (lanes at or
    // beyond warpCount contribute zeros), and lane 0 writes the total to
    // temp[0].
    if (!warp) {
        value.gradStepSq = value.qStepGradStep = value.qStepGrad = value.q = value.qStep = 0;
        if (lane < warpCount) {
            value = temp[lane];
        }
        for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
            value.gradStepSq += WARP_SHUFFLE_DOWN(value.gradStepSq, step);
            value.qStepGradStep += WARP_SHUFFLE_DOWN(value.qStepGradStep, step);
            value.qStepGrad += WARP_SHUFFLE_DOWN(value.qStepGrad, step);
            value.q += WARP_SHUFFLE_DOWN(value.q, step);
            value.qStep += WARP_SHUFFLE_DOWN(value.qStep, step);
        }
        if (!lane) {
            temp[0] = value;
        }
    }
    SYNC_THREADS;
#else
    // Fallback: local-memory tree reduction, one slot per thread.  Strides
    // below WARP_SIZE / 2 only pair slots inside one aligned group of
    // WARP_SIZE threads and synchronize with SYNC_WARPS.  Larger strides use
    // full barriers.
    temp[thread] = value;
    SYNC_THREADS;
    for (int step = 1; step < WARP_SIZE / 2; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread].gradStepSq += temp[thread + step].gradStepSq;
            temp[thread].qStepGradStep += temp[thread + step].qStepGradStep;
            temp[thread].qStepGrad += temp[thread + step].qStepGrad;
            temp[thread].q += temp[thread + step].q;
            temp[thread].qStep += temp[thread + step].qStep;
        }
        SYNC_WARPS;
    }
    for (int step = WARP_SIZE / 2; step < LOCAL_SIZE; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread].gradStepSq += temp[thread + step].gradStepSq;
            temp[thread].qStepGradStep += temp[thread + step].qStepGradStep;
            temp[thread].qStepGrad += temp[thread + step].qStepGrad;
            temp[thread].q += temp[thread + step].q;
            temp[thread].qStep += temp[thread + step].qStep;
        }
        SYNC_THREADS;
    }
#endif
    return temp[0];
}

// Performs the equivalent of reduceReal() on 3 values simultaneously.
// Each field of BlockSums2 is summed independently across the work group.
// Every thread must call this (it executes SYNC_THREADS barriers), and every
// thread gets the combined sums back.  Unlike reduceReal(), temp is not
// declared volatile.
DEVICE BlockSums2 reduceBlockSums2(BlockSums2 value, LOCAL_ARG BlockSums2* temp) {
    const int thread = LOCAL_ID;
    // Presumably keeps temp from being overwritten while a previous
    // reduction's result (temp[0]) is still being read.
    SYNC_THREADS;
#ifdef WARP_SHUFFLE_DOWN
    // NOTE(review): the division truncates, so this assumes LOCAL_SIZE is a
    // multiple of WARP_SIZE; a trailing partial warp would be dropped.
    const int warpCount = LOCAL_SIZE / WARP_SIZE;
    const int warp = thread / WARP_SIZE;
    const int lane = thread % WARP_SIZE;
    // Stage 1: shuffle-reduce every field within each warp.  Lane 0 stores the
    // warp's partial sums in temp[warp].
    for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
        value.projGradSq += WARP_SHUFFLE_DOWN(value.projGradSq, step);
        value.precGradStep += WARP_SHUFFLE_DOWN(value.precGradStep, step);
        value.precGrad += WARP_SHUFFLE_DOWN(value.precGrad, step);
    }
    if (!lane) {
        temp[warp] = value;
    }
    SYNC_THREADS;
    // Stage 2: the first warp combines the per-warp partial sums (lanes at or
    // beyond warpCount contribute zeros), and lane 0 writes the total to
    // temp[0].
    if (!warp) {
        value.projGradSq = value.precGradStep = value.precGrad = 0;
        if (lane < warpCount) {
            value = temp[lane];
        }
        for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
            value.projGradSq += WARP_SHUFFLE_DOWN(value.projGradSq, step);
            value.precGradStep += WARP_SHUFFLE_DOWN(value.precGradStep, step);
            value.precGrad += WARP_SHUFFLE_DOWN(value.precGrad, step);
        }
        if (!lane) {
            temp[0] = value;
        }
    }
    SYNC_THREADS;
#else
    // Fallback: local-memory tree reduction, one slot per thread.  Strides
    // below WARP_SIZE / 2 only pair slots inside one aligned group of
    // WARP_SIZE threads and synchronize with SYNC_WARPS.  Larger strides use
    // full barriers.
    temp[thread] = value;
    SYNC_THREADS;
    for (int step = 1; step < WARP_SIZE / 2; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread].projGradSq += temp[thread + step].projGradSq;
            temp[thread].precGradStep += temp[thread + step].precGradStep;
            temp[thread].precGrad += temp[thread + step].precGrad;
        }
        SYNC_WARPS;
    }
    for (int step = WARP_SIZE / 2; step < LOCAL_SIZE; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread].projGradSq += temp[thread + step].projGradSq;
            temp[thread].precGradStep += temp[thread + step].precGradStep;
            temp[thread].precGrad += temp[thread + step].precGrad;
        }
        SYNC_THREADS;
    }
#endif
    return temp[0];
}

// We need more than single precision for accumulation regardless of the mode
// selected, so use double if double precision is supported, and otherwise use
// double-float arithmetic.  In the latter case, use float2 with .x storing the
// "high" part and .y storing the "low" part of each double-float number.
#ifdef SUPPORTS_DOUBLE_PRECISION

#define ACCUM double
#define ACCUM_ZERO 0.0

// Mixed-type helpers; in this branch the accumulator is a plain double.
// ACCUM_ADD:     accum = accum + real
#define ACCUM_ADD(x, y) ((x) + ((ACCUM) (y)))
// ACCUM_APPLY:   real = real + accum
#define ACCUM_APPLY(x, y) ((real) (((ACCUM) (x)) + (y)))
// ACCUM_MUL:     real = accum * real
#define ACCUM_MUL(x, y) ((real) ((x) * ((ACCUM) (y))))
// ACCUM_MUL_ADD: accum = (accum * real) + accum
#define ACCUM_MUL_ADD(x, y, z) (((x) * ((ACCUM) (y))) + (z))
// ACCUM_ADD_MUL: real = (real + accum) * accum
#define ACCUM_ADD_MUL(x, y, z) ((real) ((((ACCUM) (x)) + (y)) * (z)))

// Sum value from each thread (using temp) and return (sum + offset) * scale.
// The sum is accumulated in ACCUM (double) precision; offset and scale are
// promoted to double when applied.  Every thread of the work group must call
// this (it executes SYNC_THREADS barriers), and all threads get the result.
DEVICE ACCUM reduceAccum(ACCUM value, LOCAL_ARG volatile ACCUM* temp, real offset, real scale) {
    const int thread = LOCAL_ID;
    // Presumably keeps temp from being overwritten while other threads are
    // still reading temp[0] from an earlier reduction.
    SYNC_THREADS;
#ifdef WARP_SHUFFLE_DOWN
    // NOTE(review): the division truncates, so this assumes LOCAL_SIZE is a
    // multiple of WARP_SIZE; a trailing partial warp would be dropped.
    const int warpCount = LOCAL_SIZE / WARP_SIZE;
    const int warp = thread / WARP_SIZE;
    const int lane = thread % WARP_SIZE;
    // Stage 1: per-warp shuffle reduction; lane 0 stores the warp sum.
    for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
        value += WARP_SHUFFLE_DOWN(value, step);
    }
    if (!lane) {
        temp[warp] = value;
    }
    SYNC_THREADS;
    // Stage 2: the first warp sums the per-warp totals into temp[0].
    if (!warp) {
        value = lane < warpCount ? temp[lane] : 0;
        for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
            value += WARP_SHUFFLE_DOWN(value, step);
        }
        if (!lane) {
            temp[0] = value;
        }
    }
    SYNC_THREADS;
#else
    // Fallback: local-memory tree reduction, one slot per thread.  Strides
    // below WARP_SIZE / 2 stay within one aligned group of WARP_SIZE threads
    // and synchronize with SYNC_WARPS.  Larger strides use full barriers.
    temp[thread] = value;
    SYNC_THREADS;
    for (int step = 1; step < WARP_SIZE / 2; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread] += temp[thread + step];
        }
        SYNC_WARPS;
    }
    for (int step = WARP_SIZE / 2; step < LOCAL_SIZE; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread] += temp[thread + step];
        }
        SYNC_THREADS;
    }
#endif
    return (temp[0] + offset) * scale;
}

#else

#define ACCUM float2
#define ACCUM_ZERO make_float2(0.0f, 0.0f)

// Same operations as the double-precision variants, expressed with the
// double-float (float2) helpers defined below.
// ACCUM_ADD:     accum = accum + real
#define ACCUM_ADD(x, y) compensatedAdd2((x), (y))
// ACCUM_APPLY:   real = real + accum (note the swapped argument order)
#define ACCUM_APPLY(x, y) compensatedAdd1((y), (x))
// ACCUM_MUL:     real = accum * real
#define ACCUM_MUL(x, y) compensatedMultiply1((x), (y))
// ACCUM_MUL_ADD: accum = (accum * real) + accum
#define ACCUM_MUL_ADD(x, y, z) compensatedAdd3(compensatedMultiply2((x), (y)), (z))
// ACCUM_ADD_MUL: real = (real + accum) * accum
#define ACCUM_ADD_MUL(x, y, z) compensatedMultiply3(compensatedAdd2((y), (x)), (z))

// For details of the compensated addition and multiplication implemented, see
// Joldes et al., ACM Trans. Math. Softw. 2017, 44, 15res (DOI: 10.1145/3121432).

// float + float -> float2 (hi, lo): hi is the rounded sum and lo the rounding
// error recovered from it.  Only valid if the floating-point exponent of a is
// not less than that of b.
DEVICE inline float2 compensatedAddKernel1(float a, float b) {
    const float hi = a + b;
    const float lo = b - (hi - a);
    return make_float2(hi, lo);
}

// float + float -> float2 (hi, lo), valid for any inputs.  Both operands are
// reconstructed from the rounded sum, and their individual discrepancies are
// added to form the low word.
DEVICE inline float2 compensatedAddKernel2(float a, float b) {
    const float hi = a + b;
    const float aPart = hi - b;
    const float bPart = hi - aPart;
    const float lo = (a - aPart) + (b - bPart);
    return make_float2(hi, lo);
}

// float * float -> float2.  .x is the rounded product a * b and .y is its
// rounding error, recovered as FMA(a, b, -c).
// NOTE(review): the error term is only exact if FMA evaluates a * b + c with a
// single rounding (a true fused multiply-add); confirm how FMA is defined.
DEVICE inline float2 compensatedMultiplyKernel(float a, float b) {
    float c = a * b;
    return make_float2(c, FMA(a, b, -c));
}

// float2 + float -> float.  Performs the same addition as compensatedAdd2 but
// returns only the high word, collapsing the low-order terms into it.
DEVICE inline float compensatedAdd1(float2 x, float y) {
    const float2 head = compensatedAddKernel2(x.x, y);
    const float tail = x.y + head.y;
    return head.x + tail;
}

// float2 + float -> float2, with a relative error of 2^-47.  The high word is
// added exactly, the low words are folded together, and the result is
// renormalized.
DEVICE inline float2 compensatedAdd2(float2 x, float y) {
    const float2 head = compensatedAddKernel2(x.x, y);
    return compensatedAddKernel1(head.x, x.y + head.y);
}

// float2 + float2 -> float2, with a relative error of 2^-46.  High words and
// low words are each summed exactly, then the pieces are renormalized in two
// passes.
DEVICE inline float2 compensatedAdd3(float2 x, float2 y) {
    const float2 highs = compensatedAddKernel2(x.x, y.x);
    const float2 lows = compensatedAddKernel2(x.y, y.y);
    const float2 partial = compensatedAddKernel1(highs.x, highs.y + lows.x);
    return compensatedAddKernel1(partial.x, lows.y + partial.y);
}

// float2 * float -> float.  Like compensatedMultiply2, but only computes the
// high part of the result.  c is the rounded product of the high word;
// FMA(x.x, y, -c) recovers its rounding error, and x.y * y adds the low word's
// contribution before everything is folded into one float.
DEVICE inline float compensatedMultiply1(float2 x, float y) {
    float c = x.x * y;
    return c + (FMA(x.x, y, -c) + x.y * y);
}

// float2 * float -> float2, with a relative error of 2^-47.  The rounding
// error of the high-word product (from FMA) and the low-word product x.y * y
// are summed, then renormalized against c with compensatedAddKernel1.
// NOTE(review): unless the compiler contracts it into an FMA, x.y * y is
// rounded before being added to the error term; confirm the quoted 2^-47
// bound accounts for that extra rounding.
DEVICE inline float2 compensatedMultiply2(float2 x, float y) {
    float c = x.x * y;
    return compensatedAddKernel1(c, FMA(x.x, y, -c) + x.y * y);
}

// float2 * float2 -> float.  Forms the exact product of the high words (rounded
// value and error from compensatedMultiplyKernel), then adds the cross terms
// x.y * y.x + x.x * y.y + x.y * y.y, accumulated with nested FMAs.  Only the
// high part of the result is returned.
DEVICE inline float compensatedMultiply3(float2 x, float2 y) {
    float2 c = compensatedMultiplyKernel(x.x, y.x);
    return c.x + (c.y + FMA(x.y, y.x, FMA(x.x, y.y, x.y * y.y)));
}

// Sum value from each thread (using temp) and return (sum + offset) * scale,
// using double-float (float2) arithmetic throughout.  Every thread of the work
// group must call this (it executes SYNC_THREADS barriers), and every thread
// receives the same result.
DEVICE ACCUM reduceAccum(ACCUM value, LOCAL_ARG volatile ACCUM* temp, real offset, real scale) {
    const int thread = LOCAL_ID;
    // Presumably keeps temp from being overwritten while other threads are
    // still reading temp[0] from an earlier reduction.
    SYNC_THREADS;
#ifdef WARP_SHUFFLE_DOWN
    // The warp shuffle intrinsics (__shfl_down_sync / __shfl_down) have no
    // float2 overload, and in C++ (CUDA/HIP) a volatile float2 cannot be copied
    // or assigned as a whole.  This path therefore moves the high (.x) and low
    // (.y) words separately.
    // NOTE(review): the division truncates, so this assumes LOCAL_SIZE is a
    // multiple of WARP_SIZE; a trailing partial warp would be dropped.
    const int warpCount = LOCAL_SIZE / WARP_SIZE;
    const int warp = thread / WARP_SIZE;
    const int lane = thread % WARP_SIZE;
    // Stage 1: per-warp shuffle reduction; lane 0 stores the warp sum.
    for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
        const float2 other = make_float2(WARP_SHUFFLE_DOWN(value.x, step), WARP_SHUFFLE_DOWN(value.y, step));
        value = compensatedAdd3(value, other);
    }
    if (!lane) {
        temp[warp].x = value.x;
        temp[warp].y = value.y;
    }
    SYNC_THREADS;
    // Stage 2: the first warp sums the per-warp totals (lanes at or beyond
    // warpCount contribute zero) into temp[0].
    if (!warp) {
        value = ACCUM_ZERO;
        if (lane < warpCount) {
            value = make_float2(temp[lane].x, temp[lane].y);
        }
        for (int step = WARP_SIZE / 2; step > 0; step >>= 1) {
            const float2 other = make_float2(WARP_SHUFFLE_DOWN(value.x, step), WARP_SHUFFLE_DOWN(value.y, step));
            value = compensatedAdd3(value, other);
        }
        if (!lane) {
            temp[0].x = value.x;
            temp[0].y = value.y;
        }
    }
    SYNC_THREADS;
#else
    // Fallback: local-memory tree reduction, one slot per thread.  Strides
    // below WARP_SIZE / 2 stay within one aligned group of WARP_SIZE threads
    // and synchronize with SYNC_WARPS.  Larger strides use full barriers.
    temp[thread] = value;
    SYNC_THREADS;
    for (int step = 1; step < WARP_SIZE / 2; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread] = compensatedAdd3(temp[thread], temp[thread + step]);
        }
        SYNC_WARPS;
    }
    for (int step = WARP_SIZE / 2; step < LOCAL_SIZE; step <<= 1) {
        if(thread + step < LOCAL_SIZE && thread % (2 * step) == 0) {
            temp[thread] = compensatedAdd3(temp[thread], temp[thread + step]);
        }
        SYNC_THREADS;
    }
#endif
    // Read the total component by component so no whole volatile float2 is
    // copied (not allowed in C++ targets).
    const float2 total = make_float2(temp[0].x, temp[0].y);
    return compensatedMultiply2(compensatedAdd2(total, offset), scale);
}

#endif

KERNEL void solveInitializeStep1(GLOBAL real* RESTRICT electrodeCharges, GLOBAL real* RESTRICT qLast
#ifdef USE_CHARGE_CONSTRAINT
    , real chargeTarget
#endif
) {
    // Must run as a single thread block: the constraint correction reduces
    // over every electrode particle with a block-local reduction.

#ifdef USE_CHARGE_CONSTRAINT
    LOCAL volatile ACCUM tempAccum[TEMP_SIZE];
    // This thread's running sum of the negated guess charges.
    ACCUM negatedTotal = ACCUM_ZERO;
#endif

    // Initial guess: extrapolate linearly from the current and previous
    // solver charges (2 * current - previous), then remember the current
    // charges as the new "previous" ones.
    for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
        const real current = electrodeCharges[i];
        const real guess = 2 * current - qLast[i];
        electrodeCharges[i] = guess;
        qLast[i] = current;
#ifdef USE_CHARGE_CONSTRAINT
        negatedTotal = ACCUM_ADD(negatedTotal, -guess);
#endif
    }

#ifdef USE_CHARGE_CONSTRAINT
    // Shift every guess charge by (chargeTarget - total) / N so that the
    // guess satisfies the total charge constraint.
    const ACCUM shift = reduceAccum(negatedTotal, tempAccum, chargeTarget, 1 / (real) NUM_ELECTRODE_PARTICLES);
    for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
        electrodeCharges[i] = ACCUM_APPLY(electrodeCharges[i], shift);
    }
#endif
}

KERNEL void solveInitializeStep2(GLOBAL real* RESTRICT chargeDerivatives, GLOBAL real* RESTRICT grad, GLOBAL real* RESTRICT projGrad, GLOBAL int* RESTRICT convergedResult) {
    // Must run as a single thread block: the reductions below span all
    // electrode particles.

    LOCAL volatile real temp[TEMP_SIZE];
#ifdef USE_CHARGE_CONSTRAINT
    LOCAL volatile ACCUM tempAccum[TEMP_SIZE];
    ACCUM gradTotal = ACCUM_ZERO;
#endif

    // The initial gradient is the current set of charge derivatives.  With a
    // charge constraint, also sum it so its mean can be removed.
    for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
        const real g = chargeDerivatives[i];
        grad[i] = g;
#ifdef USE_CHARGE_CONSTRAINT
        gradTotal = ACCUM_ADD(gradTotal, g);
#endif
    }

#ifdef USE_CHARGE_CONSTRAINT
    // Negative mean of the gradient, used for the unpreconditioned projection.
    const ACCUM shift = reduceAccum(gradTotal, tempAccum, 0, -1 / (real) NUM_ELECTRODE_PARTICLES);
#endif

    // Store the projected gradient and accumulate its squared norm.
    real errorSq = 0;
    for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
#ifdef USE_CHARGE_CONSTRAINT
        const real p = ACCUM_APPLY(grad[i], shift);
#else
        const real p = grad[i];
#endif
        projGrad[i] = p;
        errorSq += p * p;
    }

    // The initial guess is already converged if the projected gradient is
    // small enough.
    errorSq = reduceReal(errorSq, temp);
    if (LOCAL_ID == 0) {
        convergedResult[0] = (int) (errorSq <= ERROR_TARGET);
    }
}

KERNEL void solveInitializeStep3(GLOBAL real* RESTRICT electrodeCharges, GLOBAL real* RESTRICT chargeDerivatives, GLOBAL real* RESTRICT grad, GLOBAL real* RESTRICT projGrad,
    GLOBAL real* RESTRICT precGrad, GLOBAL real* RESTRICT qStep, GLOBAL real* RESTRICT grad0
#ifdef PRECOND_REQUESTED
    , GLOBAL ACCUM* RESTRICT precondVector, int precondActivated
#endif
) {
    // Must run as a single thread block: the constrained, preconditioned
    // projection below reduces over every electrode particle in one block.

#if defined(PRECOND_REQUESTED) && defined(USE_CHARGE_CONSTRAINT)
    LOCAL volatile ACCUM tempAccum[TEMP_SIZE];
#endif

    // Keep a copy of the current charge derivatives in grad0.
    for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
        grad0[i] = chargeDerivatives[i];
    }

    // Compute the preconditioned gradient precGrad.  When no preconditioner is
    // requested or activated, it is simply the projected gradient.
#ifdef PRECOND_REQUESTED
    if (!precondActivated) {
#endif
        for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
            precGrad[i] = projGrad[i];
        }
#ifdef PRECOND_REQUESTED
    }
    else {
#ifdef USE_CHARGE_CONSTRAINT
        // Subtract the precondVector-weighted sum of grad from each component,
        // then scale by precondVector.
        ACCUM weightedSum = ACCUM_ZERO;
        for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
            weightedSum = ACCUM_MUL_ADD(precondVector[i], grad[i], weightedSum);
        }
        const ACCUM shift = reduceAccum(weightedSum, tempAccum, 0, -1);
        for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
            precGrad[i] = ACCUM_ADD_MUL(grad[i], shift, precondVector[i]);
        }
#else
        for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
            precGrad[i] = ACCUM_MUL(precondVector[i], grad[i]);
        }
#endif
    }
#endif

    // The first conjugate gradient step vector is -precGrad.  Store it in
    // qStep and also in electrodeCharges.
    for (int i = LOCAL_ID; i < NUM_ELECTRODE_PARTICLES; i += LOCAL_SIZE) {
        const real direction = -precGrad[i];
        qStep[i] = direction;
        electrodeCharges[i] = direction;
    }
}

KERNEL void solveLoopStep1(
    GLOBAL real* RESTRICT chargeDerivatives,
    GLOBAL real* RESTRICT q,
    GLOBAL real* RESTRICT grad,
    GLOBAL real* RESTRICT qStep,
    GLOBAL real* RESTRICT gradStep,
    GLOBAL real* RESTRICT grad0,
    GLOBAL BlockSums1* RESTRICT blockSums1Buffer
) {
    // May run across multiple thread blocks; each block writes its own partial
    // sums to blockSums1Buffer[GROUP_ID].

    LOCAL BlockSums1 temp[TEMP_SIZE];

    // gradStep = chargeDerivatives - grad0, accumulated into this thread's
    // share of the dot products and totals in the same pass.
    BlockSums1 partial = {0, 0, 0, 0, 0};
    for (int i = GLOBAL_ID; i < NUM_ELECTRODE_PARTICLES; i += GLOBAL_SIZE) {
        const real dg = chargeDerivatives[i] - grad0[i];
        const real step = qStep[i];
        gradStep[i] = dg;
        partial.gradStepSq += dg * dg;
        partial.qStepGradStep += step * dg;
        partial.qStepGrad += step * grad[i];
        partial.q += q[i];
        partial.qStep += step;
    }

    // Combine across the block; thread 0 publishes the block's result.
    partial = reduceBlockSums1(partial, temp);
    if (LOCAL_ID == 0) {
        blockSums1Buffer[GROUP_ID] = partial;
    }
}

KERNEL void solveLoopStep2(
    GLOBAL BlockSums1* RESTRICT blockSums1Buffer,
    GLOBAL int* RESTRICT convergedResult
) {
    // Must run as a single thread block.  Folds the THREAD_BLOCK_COUNT partial
    // results (written by solveLoopStep1) into blockSums1Buffer[0].

    LOCAL BlockSums1 tempSums[TEMP_SIZE];

    BlockSums1 total = {0, 0, 0, 0, 0};
    for (int block = LOCAL_ID; block < THREAD_BLOCK_COUNT; block += LOCAL_SIZE) {
        const BlockSums1 part = blockSums1Buffer[block];
        total.gradStepSq += part.gradStepSq;
        total.qStepGradStep += part.qStepGradStep;
        total.qStepGrad += part.qStepGrad;
        total.q += part.q;
        total.qStep += part.qStep;
    }
    total = reduceBlockSums1(total, tempSums);

    // Only one thread publishes the result.
    if (LOCAL_ID != 0) {
        return;
    }
    blockSums1Buffer[0] = total;
    // A tiny A qStep (gradStep) ends the solve, avoiding e.g. a division by
    // zero when alpha is computed, or excessively large steps.
    convergedResult[0] = (int) (total.gradStepSq <= ERROR_TARGET);
}

KERNEL void solveLoopStep3(
    GLOBAL real* RESTRICT q,
    GLOBAL real* RESTRICT grad,
    GLOBAL real* RESTRICT qStep,
    GLOBAL real* RESTRICT gradStep,
    GLOBAL BlockSums1* RESTRICT blockSums1Buffer,
    GLOBAL int* RESTRICT convergedResult
#ifdef USE_CHARGE_CONSTRAINT
    , real chargeTarget
#endif
) {
    // May run across multiple thread blocks; every element is updated
    // independently.

    // Once convergence has been flagged, leave q and grad untouched.
    if (convergedResult[0] != 0) {
        return;
    }

    const BlockSums1 sums = blockSums1Buffer[0];
    // Step length along qStep: -(qStep . grad) / (qStep . gradStep).
    const real alpha = -sums.qStepGrad / sums.qStepGradStep;

#ifdef USE_CHARGE_CONSTRAINT
    // After the step, the total charge is sums.q + alpha * sums.qStep in exact
    // arithmetic.  Spreading the difference from chargeTarget evenly over all
    // particles removes drift accumulated through finite precision.
    const real drift = (chargeTarget - (sums.q + alpha * sums.qStep)) / NUM_ELECTRODE_PARTICLES;
#endif

    for (int i = GLOBAL_ID; i < NUM_ELECTRODE_PARTICLES; i += GLOBAL_SIZE) {
        // Move the charges along the step vector.
        real charge = q[i] + alpha * qStep[i];
#ifdef USE_CHARGE_CONSTRAINT
        charge += drift;
#endif
        q[i] = charge;
        // Update the gradient to match.  If the gradient is to be recomputed
        // on this iteration, the contents of grad will be overwritten.
        grad[i] += alpha * gradStep[i];
    }
}

// Single-thread-block kernel.  Rebuilds the projected gradient projGrad and
// the preconditioned gradient precGrad from grad, then reduces the sums used
// by solveLoopStep5 (projGradSq, precGradStep, precGrad) into
// blockSums2Buffer[0].  It also sets convergedResult[0] to whether the squared
// norm of projGrad is at most ERROR_TARGET.
KERNEL void solveLoopStep4(
    GLOBAL real* RESTRICT grad,
    GLOBAL real* RESTRICT projGrad,
    GLOBAL real* RESTRICT precGrad,
    GLOBAL real* RESTRICT gradStep,
    GLOBAL BlockSums2* RESTRICT blockSums2Buffer,
    GLOBAL int* RESTRICT convergedResult
#ifdef PRECOND_REQUESTED
    , GLOBAL ACCUM* RESTRICT precondVector,
    int precondActivated
#endif
) {
    // This kernel expects to be executed in a single thread block.

    // All threads read the same flag, so either the whole block returns here
    // or none of it does.  No thread skips a barrier the others reach.
    if (convergedResult[0] != 0) {
        return;
    }

    // Local scratch for the block-wide reductions below.
    LOCAL BlockSums2 tempSums[TEMP_SIZE];
#ifdef USE_CHARGE_CONSTRAINT
    LOCAL volatile ACCUM tempAccum[TEMP_SIZE];
#endif

    // Project the current gradient without preconditioning.  With a charge
    // constraint, this subtracts the mean of grad.
#ifdef USE_CHARGE_CONSTRAINT
    ACCUM offsetAccum = ACCUM_ZERO;
    for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
        offsetAccum = ACCUM_ADD(offsetAccum, grad[ii]);
    }
    ACCUM offset = reduceAccum(offsetAccum, tempAccum, 0, -1 / (real) NUM_ELECTRODE_PARTICLES);
    for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
        projGrad[ii] = ACCUM_APPLY(grad[ii], offset);
    }
#else
    for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
        projGrad[ii] = grad[ii];
    }
#endif

    // Project the current gradient with preconditioning.  When the
    // preconditioner is inactive, precGrad is just projGrad.
#ifdef PRECOND_REQUESTED
    if (precondActivated) {
#ifdef USE_CHARGE_CONSTRAINT
        // Subtract the precondVector-weighted sum of grad, then scale each
        // component by precondVector.
        offsetAccum = ACCUM_ZERO;
        for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
            offsetAccum = ACCUM_MUL_ADD(precondVector[ii], grad[ii], offsetAccum);
        }
        offset = reduceAccum(offsetAccum, tempAccum, 0, -1);
        for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
            precGrad[ii] = ACCUM_ADD_MUL(grad[ii], offset, precondVector[ii]);
        }
#else
        for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
            precGrad[ii] = ACCUM_MUL(precondVector[ii], grad[ii]);
        }
#endif
    }
    else {
#endif
        for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
            precGrad[ii] = projGrad[ii];
        }
#ifdef PRECOND_REQUESTED
    }
#endif

    // Reduce values to be used by all blocks in the final kernel.
    BlockSums2 blockSums2 = {0, 0, 0};
    for (int ii = LOCAL_ID; ii < NUM_ELECTRODE_PARTICLES; ii += LOCAL_SIZE) {
        blockSums2.projGradSq += projGrad[ii] * projGrad[ii];
        blockSums2.precGradStep += precGrad[ii] * gradStep[ii];
        blockSums2.precGrad += precGrad[ii];
    }
    blockSums2 = reduceBlockSums2(blockSums2, tempSums);

    // Thread 0 publishes the sums and the convergence flag for the next step.
    if (LOCAL_ID == 0) {
        blockSums2Buffer[0] = blockSums2;
        convergedResult[0] = (int) (blockSums2.projGradSq <= ERROR_TARGET);
    }
}

KERNEL void solveLoopStep5(
    GLOBAL real* RESTRICT electrodeCharges,
    GLOBAL real* RESTRICT precGrad,
    GLOBAL real* RESTRICT qStep,
    GLOBAL BlockSums1* RESTRICT blockSums1Buffer,
    GLOBAL BlockSums2* RESTRICT blockSums2Buffer,
    GLOBAL int* RESTRICT convergedResult
) {
    // May run across multiple thread blocks; every element is updated
    // independently.

    // Skip the update entirely once convergence has been flagged.
    if (convergedResult[0] != 0) {
        return;
    }

    const BlockSums1 sums1 = blockSums1Buffer[0];
    const BlockSums2 sums2 = blockSums2Buffer[0];

    // Conjugate gradient parameter: (precGrad . gradStep) / (qStep . gradStep).
    const real beta = sums2.precGradStep / sums1.qStepGradStep;

#ifdef USE_CHARGE_CONSTRAINT
    // Mean of the updated step vector.  It would be zero in exact arithmetic;
    // subtracting it keeps qStep on the constraint plane despite rounding.
    const real drift = (beta * sums1.qStep - sums2.precGrad) / NUM_ELECTRODE_PARTICLES;
#endif

    // New step vector qStep = beta * qStep - precGrad.  It is also copied to
    // electrodeCharges to prepare for the next derivative calculation.
    for (int i = GLOBAL_ID; i < NUM_ELECTRODE_PARTICLES; i += GLOBAL_SIZE) {
        real direction = beta * qStep[i] - precGrad[i];
#ifdef USE_CHARGE_CONSTRAINT
        direction -= drift;
#endif
        qStep[i] = direction;
        electrodeCharges[i] = direction;
    }
}