Commit a5bc76db authored by Michael Carilli

Tests and resnet50 example work

parent 6e9159d8
@@ -6,7 +6,7 @@ void scale_check_overflow_cuda(const at::Tensor& grads,
                                const at::Tensor& downscaled_grads);

 void scale_check_overflow(at::Tensor grads,
-                          float scale, 
+                          float scale,
                           at::Tensor overflow_buf,
                           at::Tensor downscaled_grads)
                           // const at::optional<at::Tensor> downscaled_grads)
@@ -18,7 +18,7 @@ void scale_check_overflow(at::Tensor grads,
   AT_CHECK(downscaled_grads.type().is_cuda(), "downscaled_grads must be a CUDA tensor");
   AT_CHECK(downscaled_grads.is_contiguous(), "downscaled_grads must be contiguous");
   // Make sure we are downscaling the FP32 master grads
-  AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float, 
+  AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
            "The output grads supplied to scale_check_overflow should be fp32 (master grads).")
   AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
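The AT_CHECK guards above pin down the binding's contract: CUDA grads in, contiguous fp32 master grads out, plus an int overflow buffer the kernel can flag. A minimal sketch of how a caller might drive this for dynamic loss scaling follows; the helper name, the reciprocal-scale convention, and the readback style are assumptions for illustration, not part of this commit:

// Hypothetical C++ caller (illustrative only). Assumes the elided kernel
// body multiplies each element by "scale", so the caller passes 1/loss_scale
// to downscale fp16 grads into the fp32 master grads.
void unscale_and_check(at::Tensor fp16_grads,
                       at::Tensor fp32_master_grads,
                       at::Tensor overflow_buf,   // 1-element int CUDA tensor
                       float loss_scale)
{
  overflow_buf.zero_();
  scale_check_overflow(fp16_grads, 1.0f/loss_scale, overflow_buf, fp32_master_grads);
  // Nonzero means an inf/NaN was seen: skip the optimizer step and
  // shrink the loss scale, per the usual dynamic loss scaling recipe.
  if(overflow_buf.cpu().data<int>()[0] != 0)
  {
    // skip step, reduce loss_scale
  }
}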
@@ -16,10 +16,10 @@
 // This can be optimized with ILP but it's fine for now.
 template<typename in_t>
 __global__ void scale_reduce_overflow(in_t* in,
-                                      float* out, 
+                                      float* out,
                                       int n,
                                       float scale,
-                                      volatile int* overflow_global) 
+                                      volatile int* overflow_global)
 {
   __shared__ int overflow;
@@ -36,7 +36,7 @@ __global__ void scale_reduce_overflow(in_t* in,
     if(overflow == 1)
       break;

     if(i < n)
     {
       float incoming_val = static_cast<float>(in[i]);
@@ -47,16 +47,16 @@ __global__ void scale_reduce_overflow(in_t* in,
         // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
         // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
         // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
         }
       }
     }
   }
 }

 void scale_check_overflow_cuda
-  (const at::Tensor& grads, 
+  (const at::Tensor& grads,
    float scale,
    const at::Tensor& overflow_buf,
-   const at::Tensor& downscaled_grads) 
+   const at::Tensor& downscaled_grads)
 {
   using namespace at;
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
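The kernel's comments above float an alternative: drop the shared-memory flag and the extra __syncthreads(), and let every thread poll the global flag through the cache. A rough sketch of that variant, assuming the elided scale-then-finiteness-check logic; the kernel name and the exact check are illustrative:

// Illustrative "no smem or syncs" variant from the comments above (assumed
// logic for the elided parts). Each thread reads the volatile global flag
// every trip, so a flagged overflow short-circuits without block-wide syncs.
template<typename in_t>
__global__ void scale_reduce_overflow_nosmem(in_t* in,
                                             float* out,
                                             int n,
                                             float scale,
                                             volatile int* overflow_global)
{
  int stride = gridDim.x*blockDim.x;
  for(int i = blockIdx.x*blockDim.x + threadIdx.x; i < n; i += stride)
  {
    if(*overflow_global == 1)  // volatile read, served from memory/L2
      break;
    float val = static_cast<float>(in[i])*scale;
    if(!isfinite(val))
      *overflow_global = 1;    // other threads observe it on a later trip
    out[i] = val;
  }
}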
@@ -70,10 +70,10 @@ void scale_check_overflow_cuda
   {
     // using accscalar_t = acc_type<scalar_t, true>;
     scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
-      (grads.data<scalar_t>(), 
+      (grads.data<scalar_t>(),
       downscaled_grads.data<float>(),
-      n, 
-      scale, 
+      n,
+      scale,
       overflow_buf.data<int>());
   });
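On the kernel's own note that it "can be optimized with ILP but it's fine for now": a sketch of what that could look like, with each thread handling several independent elements per loop trip so multiple loads and stores can be in flight at once. The unroll factor, kernel name, and finiteness check are illustrative, not part of the commit:

// Illustrative ILP variant (not in this commit). UNROLL independent
// iterations per thread give the scheduler multiple memory ops in flight.
template<typename in_t, int UNROLL>
__global__ void scale_reduce_overflow_ilp(in_t* in,
                                          float* out,
                                          int n,
                                          float scale,
                                          volatile int* overflow_global)
{
  int stride = gridDim.x*blockDim.x;
  for(int i = blockIdx.x*blockDim.x + threadIdx.x;
      i < n;
      i += stride*UNROLL)
  {
    #pragma unroll
    for(int j = 0; j < UNROLL; j++)
    {
      int idx = i + j*stride;
      if(idx < n)
      {
        float val = static_cast<float>(in[idx])*scale;
        if(!isfinite(val))
          *overflow_global = 1;
        out[idx] = val;
      }
    }
  }
}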