Commit 2619f1cb authored by Thor Johnsen

Resolve merge conflict

parent 91a5a87e
@@ -8,13 +8,10 @@ void fused_reversible_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor ...
 void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
 void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
-void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
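
Context note: the CHECK_* macros kept above guard the host-side wrappers before they dispatch to the CUDA implementations declared in this hunk. The wrapper bodies are not part of this diff, so the sketch below is an assumption about their shape, shown only to illustrate how CHECK_INPUT is typically used.

#include <torch/extension.h>

// Hypothetical wrapper body (not shown in this commit): validate that every
// tensor argument is a contiguous CUDA tensor before dispatching to the
// CUDA implementation declared above.
void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
  CHECK_INPUT(overflow_flag);  // expands to CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
  CHECK_INPUT(p_in);
  CHECK_INPUT(p_out);
  maybe_cast_cuda(overflow_flag, p_in, p_out);
}
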
@@ -84,8 +81,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation."); m.def("reversible_adam", &reversible_adam, "Reversible Adam optimized CUDA implementation.");
m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation."); m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation."); m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats."); m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats."); m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
} }
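
Context note: the surviving multi-tensor entry point, bound above as "adam_mt", takes a std::vector<std::vector<at::Tensor>>. The sketch below shows one plausible way to drive it from C++; the grouping of tensor_lists (one inner list per parameter, outer lists grouped by role such as gradients / params / exp_avg / exp_avg_sq), the overflow flag dtype, and the chunk size are assumptions based on the usual apex multi_tensor_apply convention, not something stated in this commit.

#include <torch/extension.h>
#include <vector>

// Hypothetical driver (not part of the commit) for the multi-tensor fused Adam
// entry point declared and bound in this file.
void run_adam_mt_example(std::vector<at::Tensor> grads,
                         std::vector<at::Tensor> params,
                         std::vector<at::Tensor> exp_avgs,
                         std::vector<at::Tensor> exp_avg_sqs,
                         float lr, float beta1, float beta2, float eps,
                         float grad_scale, int step) {
  // Overflow flag the kernels can set when they hit a non-finite gradient
  // (int dtype is an assumption).
  auto overflow_flag = at::zeros({1}, params[0].options().dtype(at::kInt));

  // Outer vector groups tensors by role; each inner vector has one entry per
  // parameter tensor (assumed layout).
  std::vector<std::vector<at::Tensor>> tensor_lists = {grads, params, exp_avgs, exp_avg_sqs};

  const int chunk_size = 2048 * 32;  // typical multi-tensor chunk size; an assumption here

  fused_adam_cuda_mt(chunk_size, overflow_flag, tensor_lists,
                     lr, beta1, beta2, eps, grad_scale,
                     step, /*mode=*/0, /*bias_correction=*/1, /*decay=*/0.f);
}
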