Commit 6e9159d8 authored by Michael Carilli

ready for testing

parent 337056c1
@@ -20,6 +20,7 @@ void scale_check_overflow(at::Tensor grads,
   // Make sure we are downscaling the FP32 master grads
   AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
            "The output grads supplied to scale_check_overflow should be fp32 (master grads).")
+  AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
 
   scale_check_overflow_cuda(grads, scale, overflow_buf, downscaled_grads);
 }
...
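The added AT_CHECK guards the CUDA launch: the kernel indexes the input and output with the same index, so mismatched sizes would read or write out of bounds. A minimal sketch of the binding's overall shape after this commit; the parameter types for scale and overflow_buf are assumptions, only the checks and the call are taken from the diff:

// Sketch only; not the verbatim apex source.
void scale_check_overflow(at::Tensor grads,
                          float scale,              // assumed float; could differ
                          at::Tensor overflow_buf,  // int flag the kernel sets to 1 on inf/NaN
                          at::Tensor downscaled_grads)
{
  // Make sure we are downscaling the FP32 master grads
  AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
           "The output grads supplied to scale_check_overflow should be fp32 (master grads).")
  // New in this commit: refuse mismatched sizes before launching the kernel.
  AT_CHECK(grads.numel() == downscaled_grads.numel(),
           "Input and output grads must be the same size.");

  scale_check_overflow_cuda(grads, scale, overflow_buf, downscaled_grads);
}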
@@ -7,16 +7,17 @@
 #include <cuda_runtime.h>
 
 #define BLOCK_SIZE 1024
-#define MAX_BLOCKS 1024
+#define NBLOCKS 160
 
 // It makes sense to lock the output type to fp32 because the downscaled
 // grads should be master grads (and in the case of Amp, the params and their
 // gradients should always be fp32.
+// This can be optimized with ILP but it's fine for now.
 template<typename in_t>
 __global__ void scale_reduce_overflow(in_t* in,
                                       float* out,
-                                      size_t n,
+                                      int n,
                                       float scale,
                                       volatile int* overflow_global)
 {
@@ -36,13 +37,16 @@ __global__ void scale_reduce_overflow(in_t* in,
     if(overflow == 1)
       break;
 
-    if(tid < n)
+    if(i < n)
     {
       float incoming_val = static_cast<float>(in[i]);
       if(isfinite(incoming_val))
         out[i] = incoming_val*scale;
       else
         *overflow_global = 1; // Blindly fire off a write. These will race but that's ok.
+        // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
+        // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
+        // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
     }
   }
 }
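The hunk above only shows the inner body of the kernel's loop; overflow, tid, and i are defined in the elided lines. A hedged reconstruction of how the pieces plausibly fit together: a per-block chunked loop (so every thread hits the same number of __syncthreads() calls), a shared-memory overflow flag polled by thread 0, and a trailing barrier so the flag is not overwritten while other threads may still be reading it. Everything outside the visible fragment is an assumption, not the verbatim apex source.

// Sketch under the assumptions stated above; only the inner if/else mirrors the diff.
#include <cuda_runtime.h>

#define BLOCK_SIZE 1024
#define NBLOCKS 160

template<typename in_t>
__global__ void scale_reduce_overflow_sketch(in_t* in,
                                             float* out,
                                             int n,
                                             float scale,
                                             volatile int* overflow_global)
{
  __shared__ int overflow;
  int tid = threadIdx.x;

  // Walk the input one grid-sized chunk at a time. The loop bound depends only
  // on blockIdx.x, so every thread in a block runs the same number of
  // iterations and the __syncthreads() calls below remain legal.
  for(int chunk_start = blockIdx.x*blockDim.x;
      chunk_start < n;
      chunk_start += gridDim.x*blockDim.x)
  {
    int i = chunk_start + tid;

    // One thread per block polls the global flag so the whole block can
    // short-circuit once any thread anywhere has seen a non-finite value.
    if(tid == 0)
      overflow = *overflow_global;
    __syncthreads();

    if(overflow == 1)
      break;

    if(i < n)
    {
      float incoming_val = static_cast<float>(in[i]);
      if(isfinite(incoming_val))
        out[i] = incoming_val*scale;
      else
        *overflow_global = 1; // racy, but any thread writing 1 is enough
    }

    // Second barrier: make sure no thread is still reading the shared flag
    // when thread 0 overwrites it at the top of the next iteration.
    __syncthreads();
  }
}

The kernel's own comments hint at a simpler alternative (dropping shared memory and the barriers and leaning on the cache), which would trade the guaranteed short-circuit for less synchronization.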
@@ -57,17 +61,15 @@ void scale_check_overflow_cuda
   using namespace at;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  size_t n = grads.numel();
-  int num_blks = 160;
+  int n = grads.numel();
 
   // Lock the output (downscaled) type to float.
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grads.type(),
     "scale_check_overflow_cuda",
     [&]
     {
       // using accscalar_t = acc_type<scalar_t, true>;
-      scale_reduce_overflow<<<num_blks, BLOCK_SIZE, 0, stream>>>
+      scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
         (grads.data<scalar_t>(),
          downscaled_grads.data<float>(),
          n,
...
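For context on how a caller consumes overflow_buf: with dynamic loss scaling, the host launches the kernel, reads the flag back, and skips the optimizer step (or lowers the scale) if anything was non-finite. A hypothetical, heavily simplified host-side sketch follows; the wrapper name check_and_downscale, the tensor setup, and the per-tensor sync are illustrative, not apex API. Only scale_check_overflow_cuda comes from the diff, and its exact signature here is assumed.

// Hypothetical usage sketch (not apex code): downscale grads into fp32 master
// grads and report whether any element overflowed.
#include <ATen/ATen.h>

void scale_check_overflow_cuda(at::Tensor grads, float scale,
                               at::Tensor overflow_buf, at::Tensor downscaled_grads);

bool check_and_downscale(at::Tensor grads,         // fp16 or fp32 gradients
                         at::Tensor master_grads,  // fp32, same numel as grads
                         float loss_scale)
{
  // Single int flag on the same device; the kernel sets it to 1 on inf/NaN.
  auto overflow_buf = at::zeros({1}, grads.options().dtype(at::kInt));

  // Multiply by 1/loss_scale while scanning for non-finite values.
  scale_check_overflow_cuda(grads, 1.0f/loss_scale, overflow_buf, master_grads);

  // Device->host read; a real loss scaler would defer or batch this check to
  // avoid synchronizing once per tensor.
  return overflow_buf.item<int>() != 0;
}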