#include <torch/extension.h>

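// Launcher for the fused overflow-check + scale CUDA kernel, implemented in
// the accompanying .cu translation unit. Judging by the parameter names and
// the pybind docstring below, d_buf is the device-side overflow flag the
// kernel sets if it encounters an inf/NaN while downscaling.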
void scale_check_overflow_cuda(const at::Tensor& grads,
                               float scale,
                               const at::Tensor& d_buf,
                               const at::Tensor& downscaled_grads);

void scale_check_overflow(at::Tensor grads,
                          float scale,
                          at::Tensor overflow_buf,
                          at::Tensor downscaled_grads)
                          // const at::optional<at::Tensor> downscaled_grads)
{
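  // Sanity checks: every tensor must live on the GPU and be contiguous so
  // the kernel can index each one as a flat array.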
  AT_CHECK(grads.type().is_cuda(), "grads must be a CUDA tensor");
  AT_CHECK(grads.is_contiguous(), "grads must be contiguous");
  AT_CHECK(overflow_buf.type().is_cuda(), "overflow_buf must be a CUDA tensor");
  AT_CHECK(overflow_buf.is_contiguous(), "overflow_buf must be contiguous");
  AT_CHECK(downscaled_grads.type().is_cuda(), "downscaled_grads must be a CUDA tensor");
  AT_CHECK(downscaled_grads.is_contiguous(), "downscaled_grads must be contiguous");
  // Make sure we are downscaling the FP32 master grads
  AT_CHECK(downscaled_grads.type().scalarType() == at::ScalarType::Float,
    "The output grads supplied to scale_check_overflow should be fp32 (master grads).");
  AT_CHECK(grads.numel() == downscaled_grads.numel(), "Input and output grads must be the same size.");

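  // All inputs validated; hand off to the CUDA launcher, passing overflow_buf
  // as the kernel's d_buf overflow flag.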
  scale_check_overflow_cuda(grads, scale, overflow_buf, downscaled_grads);
}

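// Expose the checked entry point to Python as a PyTorch C++ extension.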
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scale_check_overflow", &scale_check_overflow, "Fused overflow check + scale for FP32 tensors");
}
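
// Python-side usage sketch (hedged: "ext" is a hypothetical handle to the
// built extension; the actual module name is whatever the build script
// registers under TORCH_EXTENSION_NAME):
//   ext.scale_check_overflow(grads, 1.0 / loss_scale,
//                            overflow_buf, master_grads)
// where overflow_buf would be a small CUDA int tensor (e.g. one element,
// initialized to zero) that the caller inspects afterwards for overflow.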