Commit 86ad07f7 authored by Thor Johnsen's avatar Thor Johnsen
Browse files

Bug fix

parent 8cc99c29
#include <torch/extension.h>
// CUDA forward declaration
void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment