Commit a5bc76db authored by Michael Carilli

Tests and resnet50 example work

parent 6e9159d8
@@ -6,7 +6,7 @@ void scale_check_overflow_cuda(const at::Tensor& grads,
                               const at::Tensor& downscaled_grads);
void scale_check_overflow(at::Tensor grads,
                          float scale,
                          at::Tensor overflow_buf,
                          at::Tensor downscaled_grads)
                          // const at::optional<at::Tensor> downscaled_grads)
@@ -18,7 +18,7 @@ void scale_check_overflow(at::Tensor grads,
  AT_CHECK(downscaled_grads.type().is_cuda(), "downscaled_grads must be a CUDA tensor");
  AT_CHECK(downscaled_grads.is_contiguous(), "downscaled_grads must be contiguous");
  // Make sure we are downscaling the FP32 master grads
  AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
           "The output grads supplied to scale_check_overflow should be fp32 (master grads).")
  AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");
...
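Not visible in this hunk is how scale_check_overflow reaches Python. A typical torch C++ extension exposes it through a pybind11 module like the minimal sketch below; the forward declaration, docstring wording, and placement are assumptions for illustration, not part of this commit.

// Hypothetical binding sketch (not part of this diff): expose the checked
// entry point through the standard torch extension mechanism.
#include <torch/extension.h>

// Forward declaration of the function defined above in the same csrc file.
void scale_check_overflow(at::Tensor grads,
                          float scale,
                          at::Tensor overflow_buf,
                          at::Tensor downscaled_grads);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale_check_overflow", &scale_check_overflow,
        "Fused gradient downscale with inf/nan detection (CUDA)");
}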
@@ -16,10 +16,10 @@
// This can be optimized with ILP but it's fine for now.
template<typename in_t>
__global__ void scale_reduce_overflow(in_t* in,
                                      float* out,
                                      int n,
                                      float scale,
                                      volatile int* overflow_global)
{
  __shared__ int overflow;
@@ -36,7 +36,7 @@ __global__ void scale_reduce_overflow(in_t* in,
    if(overflow == 1)
      break;
    if(i < n)
    {
      float incoming_val = static_cast<float>(in[i]);
@@ -47,16 +47,16 @@ __global__ void scale_reduce_overflow(in_t* in,
      // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
      // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
      // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
    }
  }
}
void scale_check_overflow_cuda
  (const at::Tensor& grads,
   float scale,
   const at::Tensor& overflow_buf,
   const at::Tensor& downscaled_grads)
{
  using namespace at;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
@@ -70,10 +70,10 @@ void scale_check_overflow_cuda
  {
    // using accscalar_t = acc_type<scalar_t, true>;
    scale_reduce_overflow<<<NBLOCKS, BLOCK_SIZE, 0, stream>>>
      (grads.data<scalar_t>(),
       downscaled_grads.data<float>(),
       n,
       scale,
       overflow_buf.data<int>());
  });
...
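The hunks above show only fragments of the kernel body. For orientation, here is a minimal sketch of the pattern the visible pieces imply: a grid-stride loop over n elements that short-circuits once a shared overflow flag is set, checks each value with isfinite, and writes the fp32 output. The loop bounds, the flag initialization, and the two __syncthreads() placements are assumptions made so the sketch is self-consistent; the commit's own arrangement differs (its comments note the flag write is not immediately visible to other threads), and whether the kernel multiplies or divides by scale is not visible in this diff.

// Sketch only: structure assumed for correctness of the sketch, not the
// arrangement used in this commit.
template<typename in_t>
__global__ void scale_reduce_overflow_sketch(in_t* in,
                                             float* out,
                                             int n,
                                             float scale,
                                             volatile int* overflow_global)
{
  __shared__ int overflow;
  if(threadIdx.x == 0)
    overflow = 0;                       // assumes the host zeroed *overflow_global
  __syncthreads();

  // Grid-stride loop: every thread in a block runs the same number of
  // iterations, so the __syncthreads() calls below are reached uniformly.
  for(int start = blockIdx.x * blockDim.x;
      start < n;
      start += gridDim.x * blockDim.x)
  {
    if(overflow == 1)                   // uniform across the block: last write
      break;                            // happened before the previous sync

    int i = start + threadIdx.x;
    int saw_non_finite = 0;
    if(i < n)
    {
      float incoming_val = static_cast<float>(in[i]);
      if(isfinite(incoming_val))
        out[i] = incoming_val * scale;  // apply the scale factor to the fp32 output
      else
        saw_non_finite = 1;
    }

    __syncthreads();                    // everyone is past the top-of-loop check
    if(saw_non_finite)
    {
      overflow = 1;                     // block-local short circuit
      *overflow_global = 1;             // visible to the host after the kernel
    }
    __syncthreads();                    // publish the flag for the next iteration
  }
}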