OpenDAS / apex / Commits

Commit 4a01ff26, authored Apr 16, 2020 by Thor Johnsen
Parent: 2622d7f1

    Partial move towards syncfree optimizer

Showing 3 changed files with 224 additions and 178 deletions (+224, -178):
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp        +19  -16
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu  +197 -160
apex/contrib/optimizers/distributed_fused_adam.py       +8   -2
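Note: reading the diff below, the thrust of this commit (an interpretation based on the title and the changes, not stated elsewhere) is to replace host-side "noop" flags with a device-side overflow_flag: the new maybe_* entry points take that flag as an argument and skip their work on-device, so the host never has to block on the flag's value. A minimal Python sketch of the two styles; step_fn and maybe_step_fn are illustrative placeholders, not apex APIs:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    overflow_buf = torch.zeros(1, dtype=torch.int, device=device)

    def synchronous_step(step_fn, *args):
        # Conventional loss-scaler flow: the host reads the flag, which forces
        # a blocking device-to-host copy before it can decide to skip the step.
        if overflow_buf.item() == 0:
            step_fn(*args)

    def syncfree_step(maybe_step_fn, *args):
        # Sync-free flow: the flag tensor is handed to the kernel, which checks
        # it on-device and turns itself into a no-op if overflow was recorded.
        maybe_step_fn(overflow_buf, *args)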
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
 #include <torch/extension.h>

 // CUDA forward declaration
-void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
+void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
 void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
 void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
+void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
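Note: the kernel bodies live in fused_adam_cuda_kernel.cu (collapsed below), so the semantics of the new update_step_and_loss_scaler_cuda are not visible in this file. Purely as a guess at intent from the declaration, the following Python reference sketches a dynamic loss scaler whose state stays entirely on device; the backoff/growth constants and state layout are assumptions, not the kernel's confirmed logic:

    import torch

    def update_step_and_loss_scaler_reference(overflow, step, good_steps, scale,
                                              backoff=0.5, growth=2.0, interval=2000):
        # All inputs are 1-element tensors; everything stays on device.
        # On overflow: keep the step count, reset the good-step run, back off the scale.
        # Otherwise: advance the step, extend the run, and grow the scale after
        # `interval` consecutive good steps.
        step = torch.where(overflow, step, step + 1)
        good_steps = torch.where(overflow, torch.zeros_like(good_steps), good_steps + 1)
        grow = (~overflow) & (good_steps % interval == 0)
        scale = torch.where(overflow, scale * backoff,
                            torch.where(grow, scale * growth, scale))
        return step, good_steps, scale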
...
@@ -18,13 +20,13 @@ void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::
 // C++ interface
-void strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first) {
+void strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first) {
     CHECK_INPUT(p_copy);
-    fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
+    fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
 }

 void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
...
@@ -40,7 +42,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
     fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

-void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
+void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
     CHECK_INPUT(m);
     CHECK_INPUT(v);
...
@@ -50,23 +52,24 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
     AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
     AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
-    fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
+    fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

-void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
+void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
     CHECK_INPUT(p_in);
     CHECK_INPUT(p_out);
     int64_t num_elem = p_in.numel();
     AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
-    unpack_e5m2_cuda(p_in, p_out);
+    maybe_cast_cuda(overflow_flag, p_in, p_out);
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
     m.def("adam", &adam, "Adam optimized CUDA implementation.");
-    m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
     m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
-    m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
-    m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
     m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
 }
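Note: if the extension is rebuilt from this revision, the renamed bindings would be driven from Python roughly as below. This is a sketch, not a tested call: it assumes a CUDA device and that the module imports as fused_adam_cuda (as distributed_fused_adam.py does); the hyperparameter values are illustrative.

    import torch
    import fused_adam_cuda  # the extension built from this file

    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
    p = torch.randn(1024, device='cuda')
    m = torch.zeros_like(p)
    v = torch.zeros_like(p)
    g = torch.randn_like(p)

    # Signature mirrors the declaration above:
    # (overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale,
    #  step, mode, bias_correction, decay)
    fused_adam_cuda.maybe_adam_undo(
        overflow_buf, p, m, v, g,
        1e-3, 0.9, 0.999, 1e-8, 1.0,
        1, 0, 1, 0.0)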
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
This diff is collapsed.
apex/contrib/optimizers/distributed_fused_adam.py
...
@@ -154,6 +154,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             if torch.distributed.get_rank() in ranks:
                 self._ar_pg.append(grp)
         self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+        for ar_pg in self._ar_pg:
+            torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
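Note: the added loop issues one all_reduce of self._overflow_buf over every freshly created allreduce group. This reads as communicator initialization (an inference: NCCL builds a communicator lazily on a group's first collective), so the first real overflow all-reduce later does not pay the setup cost; the same pattern recurs for the L2-grad-norm and allgather groups in the next two hunks. A minimal standalone sketch of the pattern:

    import torch.distributed as dist

    def warm_up(buf, process_groups):
        # One throwaway collective per group forces communicator creation now,
        # off the training-step critical path.
        for pg in process_groups:
            dist.all_reduce(buf, group=pg)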
...
@@ -166,6 +168,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 self._rs_pg.append(grp)
             if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                 self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
         self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
...
@@ -180,6 +183,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if torch.distributed.get_rank() in ranks:
                     self._ag_pg.append(grp)
         self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+        for ag_pg in self._ag_pg:
+            torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
         self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
         self._completion_st = torch.cuda.Stream()
...
@@ -452,7 +457,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 beta1, beta2 = group['betas']
                 if undo:
                     if self._revert_method == 1:
-                        fused_adam_cuda.adam_undo(
+                        fused_adam_cuda.maybe_adam_undo(
+                            torch.empty([0]),
                             self._fp32_p[group_buffer_start:group_buffer_end],
                             self._fp32_m[group_buffer_start:group_buffer_end],
                             self._fp32_v[group_buffer_start:group_buffer_end],
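Note the new first argument: passing torch.empty([0]) as the overflow flag appears to request the old unconditional undo, i.e. an empty flag tensor disables the on-device overflow check for the _revert_method == 1 path. This is an inference from the call site; the kernel-side handling of the empty tensor is in the collapsed .cu diff above.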
...
@@ -576,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     param_i += 1
             if self._e5m2_allgather:
-                multi_tensor_applier(fused_adam_cuda.unpack_e5m2_mt, self._overflow_buf, [p_in, p_out]);
+                multi_tensor_applier(fused_adam_cuda.maybe_cast_mt, self._overflow_buf, [p_in, p_out]);
             elif self._do_not_flatten_model:
...
...
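Note: the e5m2 allgather path now routes through maybe_cast_mt instead of unpack_e5m2_mt, so unpacking the byte buffer back to the working dtype is also gated by self._overflow_buf. A hedged usage sketch; it assumes apex is installed with this extension built, and p_in_list/p_out_list are illustrative lists of CUDA tensors:

    import torch
    from apex.multi_tensor_apply import multi_tensor_applier
    import fused_adam_cuda

    overflow_buf = torch.zeros(1, dtype=torch.int, device='cuda')
    p_in_list = [torch.zeros(256, dtype=torch.uint8, device='cuda')]
    p_out_list = [torch.zeros(256, dtype=torch.float16, device='cuda')]

    # Skipped on-device when overflow_buf is set; otherwise casts/unpacks each
    # chunk of p_in_list into the matching tensor of p_out_list.
    multi_tensor_applier(fused_adam_cuda.maybe_cast_mt, overflow_buf,
                         [p_in_list, p_out_list])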