"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "4344da3d3f6c766216f2ec932916433055af05f4"
Commit 4a01ff26 authored by Thor Johnsen

Partial move towards syncfree optimizer

parent 2622d7f1
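The commit message is terse, so a gloss: "syncfree" here appears to mean that loss-scaling overflow handling moves onto the GPU. The old noop/noop_flag arguments become a GPU-resident overflow_flag, and the new maybe_* entry points decide on-device whether to act on it, so the host never has to read the flag back (a read that would force a CUDA synchronization). A minimal sketch of that gating pattern in plain PyTorch (illustrative only, not the extension's API):

import torch

def syncfree_apply(overflow_flag: torch.Tensor, p: torch.Tensor, update: torch.Tensor) -> None:
    # Illustrative only: the decision stays on the GPU. A real CUDA kernel
    # reads the flag on-device; torch.where below achieves the same effect
    # without the host ever calling .item() on the flag (which would sync).
    skip = overflow_flag.bool()              # 1-element GPU tensor
    p.copy_(torch.where(skip, p, p - update))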
 #include <torch/extension.h>

 // CUDA forward declaration
-void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
+void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
 void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
+void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
-void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")

@@ -18,13 +20,13 @@ void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::

 // C++ interface

 void strided_check_finite(
-        at::Tensor& noop,
+        at::Tensor& overflow_flag,
         at::Tensor& p_copy,
         int stride,
         int clear_overflow_first
 ) {
     CHECK_INPUT(p_copy);
-    fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
+    fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
 }

 void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);

@@ -40,7 +42,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a

     fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

-void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
+void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
     CHECK_INPUT(m);
     CHECK_INPUT(v);

@@ -50,23 +52,24 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f

     AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
     AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
-    fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
+    fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

-void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
+void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
     CHECK_INPUT(p_in);
     CHECK_INPUT(p_out);
     int64_t num_elem = p_in.numel();
     AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
-    unpack_e5m2_cuda(p_in, p_out);
+    maybe_cast_cuda(overflow_flag, p_in, p_out);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
     m.def("adam", &adam, "Adam optimized CUDA implementation.");
-    m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
     m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
-    m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
-    m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
-    m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
 }
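For reference, the arithmetic maybe_adam_undo must reverse is a single Adam step. Below is a simplified plain-PyTorch sketch of the step and its algebraic inverse; it ignores the extension's mode, decay, and grad_scale handling, and it checks the flag on the host only for readability (the CUDA kernel presumably consults overflow_flag on-device):

import torch

def adam_step(p, m, v, g, lr, beta1, beta2, eps, step):
    # Standard Adam with bias correction.
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    p.sub_(lr * m_hat / (v_hat.sqrt() + eps))

def maybe_adam_undo_ref(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, step):
    # Reverse adam_step, but only when an overflow was flagged.
    if overflow_flag.numel() > 0 and not overflow_flag.item():
        return  # no overflow: keep the step
    m_hat = m / (1 - beta1 ** step)
    v_hat = v / (1 - beta2 ** step)
    p.add_(lr * m_hat / (v_hat.sqrt() + eps))           # put the update back
    m.sub_(g, alpha=1 - beta1).div_(beta1)              # revert first moment
    v.addcmul_(g, g, value=-(1 - beta2)).div_(beta2)    # revert second moment

Note that p must be restored before m and v are reverted, since the step computed the denominator from the already-updated moments. An empty flag tensor is treated here as "no flag supplied, always undo", matching how the Python side below calls the binding with torch.empty([0]).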
@@ -154,6 +154,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):

                 if torch.distributed.get_rank() in ranks:
                     self._ar_pg.append(grp)
         self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+        for ar_pg in self._ar_pg:
+            torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])

@@ -166,6 +168,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):

                     self._rs_pg.append(grp)
                 if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                     self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                    torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
         self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg

@@ -180,6 +183,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):

                 if torch.distributed.get_rank() in ranks:
                     self._ag_pg.append(grp)
         self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+        for ag_pg in self._ag_pg:
+            torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
         self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
         self._completion_st = torch.cuda.Stream()
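The unconditional all_reduce of self._overflow_buf on each freshly created process group reads as a communicator warm-up: torch.distributed.new_group creates its NCCL communicator lazily on first use, so issuing a throwaway collective at construction time presumably pulls that one-time initialization (and its implicit synchronization) out of the step path, which matters once the optimizer stops synchronizing on the overflow flag.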
@@ -452,7 +457,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):

         beta1, beta2 = group['betas']
         if undo:
             if self._revert_method == 1:
-                fused_adam_cuda.adam_undo(
+                fused_adam_cuda.maybe_adam_undo(
+                        torch.empty([0]),
                         self._fp32_p[group_buffer_start:group_buffer_end],
                         self._fp32_m[group_buffer_start:group_buffer_end],
                         self._fp32_v[group_buffer_start:group_buffer_end],
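Note the dummy torch.empty([0]) passed as the new overflow_flag argument: in this revert path the undo is already known to be needed, so an empty tensor evidently stands for "no flag, undo unconditionally".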
@@ -576,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):

                 param_i += 1
         if self._e5m2_allgather:
             multi_tensor_applier(
-                fused_adam_cuda.unpack_e5m2_mt,
+                fused_adam_cuda.maybe_cast_mt,
                 self._overflow_buf,
                 [p_in, p_out]);
         elif self._do_not_flatten_model:
...
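maybe_cast keeps the old unpack_e5m2 job, widening byte-packed e5m2 values for the all-gather, but now takes the overflow flag so the kernel can presumably skip the work when an overflow occurred. A rough plain-PyTorch approximation using the float8_e5m2 dtype from recent PyTorch (the kernel predates that dtype, so this is only a model of the behavior, not the kernel itself):

import torch

def maybe_cast_ref(overflow_flag: torch.Tensor, p_in: torch.Tensor, p_out: torch.Tensor) -> None:
    # Approximate reference: reinterpret the uint8 buffer as float8 e5m2 and
    # widen it into p_out. The host-side .item() check is for readability
    # only; the real kernel would consult overflow_flag on-device.
    if overflow_flag.numel() > 0 and overflow_flag.item() != 0:
        return  # overflow detected: leave p_out untouched
    p_out.copy_(p_in.view(torch.float8_e5m2).to(p_out.dtype))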