// fused_adam_cuda.cpp
#include <torch/extension.h>

// CUDA forward declaration
Deyu Fu's avatar
Deyu Fu committed
4
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
5

6
7
8
void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);


9
10
11
12
13
// Argument-validation helpers. Each check fails through AT_ASSERTM with a
// message naming the offending argument (via #x stringization).
// Uses Tensor::is_cuda() instead of the deprecated Tensor::type().is_cuda().
#define CHECK_CUDA(x) AT_ASSERTM((x).is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous")
// NOTE: CHECK_INPUT expands to two separate statements. When it is the body of
// an if/else/loop, wrap it in braces; otherwise only CHECK_CUDA is guarded.
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

// C++ interface
Deyu Fu's avatar
Deyu Fu committed
14
void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
15
16
17
18
19
20
21
22
23
24
25
        CHECK_INPUT(p)
        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
        CHECK_INPUT(m);
        CHECK_INPUT(v);
        CHECK_INPUT(g);
        int64_t num_elem = p.numel();
        AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
        AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
        AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");

Deyu Fu's avatar
Deyu Fu committed
26
        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
27
28
29
30
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        // Expose both entry points on the extension module. module::def returns
        // the module itself, so the two registrations are chained:
        //   "adam"    -> the validating C++ wrapper adam()
        //   "adam_mt" -> fused_adam_cuda_mt, bound directly (no checks in this file)
        m.def("adam", &adam, "Adam optimized CUDA implementation.")
         .def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
}