OpenDAS / apex / Commits / 08e88b1b

Commit 08e88b1b (unverified), authored Dec 01, 2021 by athitten; committed by GitHub on Dec 01, 2021

Enable Distributed FusedLAMB (#57)

parent 51b402df
Showing 3 changed files with 401 additions and 234 deletions (+401 −234)
  apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp          +7   −3
  apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu    +41  −23
  apex/contrib/optimizers/distributed_fused_lamb.py                   +353 −208
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp  (view file @ 08e88b1b)

@@ -8,11 +8,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale);
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm);
 
 void multi_tensor_lamb_update_weights_cuda(
   int chunk_size,
@@ -20,8 +22,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
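Both declarations trade host-side scalars for device tensors. A hedged sketch (not taken from the commit) of what that means for the buffers the Python side now hands to these bindings; the variable names are illustrative:

```python
import torch

# Hedged sketch: host scalars became device tensors so the step kernels can
# read them without a CPU synchronization. Names below are illustrative.
step = torch.cuda.IntTensor([0])                             # was: const int step
lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')   # was: const float learning_rate
global_scale = torch.tensor([65536.0], device='cuda')        # current loss scale
global_grad_norm = torch.zeros([1], device='cuda')           # filled by an L2-norm all-reduce
max_grad_norm = 1.0                                          # clip threshold, still a host float
```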
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu  (view file @ 08e88b1b)

@@ -116,28 +116,36 @@ struct DistOptLAMBStage1Functor
     const MATH_T* per_tensor_beta2,
     const MATH_T* per_tensor_beta3,
     const int* per_tensor_bias_correction,
-    const int step,
+    const int* step,
     const MATH_T* per_tensor_epsilon,
     adamMode_t mode,
     const MATH_T* per_tensor_decay,
-    const float grad_scale)
+    const MATH_T* global_scale,
+    const MATH_T* global_grad_norm,
+    const float max_grad_norm)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
     if (*noop_gmem == 1)
       return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
     int chunk_idx = tl.block_to_chunk[blockIdx.x];
     int n = tl.sizes[tensor_loc];
 
+    float combined_scale = *global_scale;
+    if (max_grad_norm > 0) {
+        combined_scale = max_grad_norm / (*global_grad_norm / *global_scale + 1e-6);
+        combined_scale = *global_scale / std::min((float) 1.0, combined_scale);
+    }
+
     MATH_T beta1 = per_tensor_beta1[tensor_num];
     MATH_T beta2 = per_tensor_beta2[tensor_num];
     MATH_T beta3 = 1 - beta1;
     MATH_T beta1_correction, beta2_correction;
     if (per_tensor_bias_correction[tensor_num] == 1) {
-      beta1_correction = 1 - pow(beta1, step);
-      beta2_correction = 1 - pow(beta2, step);
+      beta1_correction = 1 - pow(beta1, *step);
+      beta2_correction = 1 - pow(beta2, *step);
     } else {
       beta1_correction = (MATH_T) 1.0;
       beta2_correction = (MATH_T) 1.0;
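For orientation, here is the same clipping math written out in plain Python (a hedged sketch, not part of the commit). The gradients in the buffer are still multiplied by the loss scale, so dividing each gradient by `combined_scale` both unscales it and applies global-norm clipping in one pass:

```python
# Hedged sketch of the fused clipping math in DistOptLAMBStage1Functor.
def combined_scale(global_scale, global_grad_norm, max_grad_norm):
    scale = global_scale                       # no clipping requested
    if max_grad_norm > 0:
        clip = max_grad_norm / (global_grad_norm / global_scale + 1e-6)
        scale = global_scale / min(1.0, clip)  # grows beyond the loss scale when clipping kicks in
    return scale

# Example: loss scale 1024, unscaled grad norm 5.0, clip threshold 1.0.
# combined_scale(1024.0, 5.0 * 1024.0, 1.0) ~= 5120, so dividing the scaled
# gradients (norm 5120) by it leaves a gradient norm of exactly 1.0.
```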
@@ -204,7 +212,7 @@ struct DistOptLAMBStage1Functor
         for (int ii = 0; ii < ILP; ii++) {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay * r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -215,7 +223,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           } else {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -274,7 +282,7 @@ struct DistOptLAMBStage1Functor
         for (int ii = 0; ii < ILP; ii++) {
           if (mode == MOMENT_MODE_0) {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             // L2 on scaled grad
             scaled_grad = scaled_grad + decay * r_p[ii];
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
@@ -285,7 +293,7 @@ struct DistOptLAMBStage1Functor
             r_p[ii] = next_m_unbiased / denom;
           } else {
-            MATH_T scaled_grad = r_g[ii] / grad_scale;
+            MATH_T scaled_grad = r_g[ii] / combined_scale;
             r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
             r_v[ii] = r_v[ii] * beta2 + (1 - beta2) * scaled_grad * scaled_grad;
             MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
@@ -321,13 +329,15 @@ struct DistOptLAMBStage2Functor
     TensorListMetadata<3>& tl,
     const MATH_T* per_tensor_param_norm,
     const MATH_T* per_tensor_update_norm,
-    const MATH_T learning_rate,
+    const long* update_norm_offset,
+    const MATH_T* learning_rate,
     const MATH_T* per_tensor_decay,
+    const MATH_T* global_grad_norm,
     bool use_nvlamb)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
     if (*noop_gmem == 1)
       return;
 
     int tensor_loc = tl.block_to_tensor[blockIdx.x];
     int tensor_num = tl.start_tensor_this_launch + tensor_loc;
@@ -336,14 +346,14 @@ struct DistOptLAMBStage2Functor
     MATH_T decay = per_tensor_decay[tensor_num];
 
-    MATH_T ratio = learning_rate;
+    MATH_T ratio = *learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
     // otherwise, only apply to those with non-zero weight decay
     if (use_nvlamb || (decay != (MATH_T) 0.0)) {
       MATH_T param_norm = per_tensor_param_norm[tensor_num];
-      MATH_T update_norm = per_tensor_update_norm[tensor_num];
-      ratio = (update_norm != 0.0 && param_norm != 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
+      MATH_T update_norm = per_tensor_update_norm[update_norm_offset[tensor_num]];
+      ratio = (update_norm != 0.0 && param_norm != 0.0) ? (*learning_rate) * (param_norm / update_norm) : (*learning_rate);
     }
 
     MATH_T* update = (MATH_T*) tl.addresses[0][tensor_loc];
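The trust-ratio logic itself is unchanged; what changes is that the learning rate is read from a device tensor and the update norm is looked up through `update_norm_offset`, so each shard indexes the norm of its full parameter. A hedged Python restatement of the ratio (illustrative, not from the commit):

```python
# Hedged sketch of the per-tensor LAMB trust ratio in DistOptLAMBStage2Functor.
def lamb_ratio(lr, param_norm, update_norm, decay, use_nvlamb):
    ratio = lr
    if use_nvlamb or decay != 0.0:
        if update_norm != 0.0 and param_norm != 0.0:
            ratio = lr * (param_norm / update_norm)
    return ratio  # the weight update is then p -= ratio * u
```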
@@ -374,7 +384,7 @@ struct DistOptLAMBStage2Functor
 #pragma unroll
         for (int ii = 0; ii < ILP; ii++) {
-          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
+          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
           convert(r_p[ii], r_p_copy[ii]);
         }
         load_store(p, r_p, i_start, 0);
@@ -427,11 +437,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
   at::Tensor per_tensor_beta2,
   at::Tensor per_tensor_beta3,
   at::Tensor per_tensor_bias_correction,
-  const int step,
+  at::Tensor step,
   at::Tensor per_tensor_epsilon,
   const int mode,
   at::Tensor per_tensor_decay,
-  const float grad_scale)
+  at::Tensor global_scale,
+  at::Tensor global_grad_norm,
+  const float max_grad_norm)
 {
   using namespace at;
@@ -448,11 +460,13 @@ void multi_tensor_lamb_compute_update_term_cuda(
         per_tensor_beta2.DATA_PTR<scalar_t_2>(),
         per_tensor_beta3.DATA_PTR<scalar_t_2>(),
         per_tensor_bias_correction.DATA_PTR<int>(),
-        step,
+        step.DATA_PTR<int>(),
         per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
         (adamMode_t) mode,
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
-        grad_scale); )))
+        global_scale.DATA_PTR<scalar_t_2>(),
+        global_grad_norm.DATA_PTR<scalar_t_2>(),
+        max_grad_norm); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
@@ -463,8 +477,10 @@ void multi_tensor_lamb_update_weights_cuda(
   std::vector<std::vector<at::Tensor>> tensor_lists,
   at::Tensor per_tensor_param_norm,
   at::Tensor per_tensor_update_norm,
-  const float learning_rate,
+  at::Tensor update_norm_offset,
+  at::Tensor learning_rate,
   at::Tensor per_tensor_decay,
+  at::Tensor global_grad_norm,
   bool use_nvlamb)
 {
   using namespace at;
@@ -480,8 +496,10 @@ void multi_tensor_lamb_update_weights_cuda(
         DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
         per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
         per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
-        (scalar_t_2) learning_rate,
+        update_norm_offset.DATA_PTR<long>(),
+        learning_rate.DATA_PTR<scalar_t_2>(),
         per_tensor_decay.DATA_PTR<scalar_t_2>(),
+        global_grad_norm.DATA_PTR<scalar_t_2>(),
         use_nvlamb); )))
 
   AT_CUDA_CHECK(cudaGetLastError());
apex/contrib/optimizers/distributed_fused_lamb.py  (view file @ 08e88b1b)

@@ -4,36 +4,38 @@ import importlib
 import amp_C
 from apex.multi_tensor_apply import multi_tensor_applier
 
+import torch.distributed.distributed_c10d as c10d
+
 class DistributedFusedLAMB(torch.optim.Optimizer):
 
     """Implements LAMB algorithm.
 
     Currently GPU-only.  Requires Apex to be installed via
     ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
 
     This version of fused LAMB implements 2 fusions.
 
      * Fusion of the LAMB update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
 
    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
 
        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        ...
        opt.step()
 
    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp.  If you wish to use :class:`FusedLAMB` with Amp,
    you may choose any ``opt_level``::
 
        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
        ...
        opt.step()
 
    In general, ``opt_level="O1"`` is recommended.
 
    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
 
    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
@@ -56,24 +58,36 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            (default: 1.0)
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)
-       clip_grad_norm (boolean, optional): whether to handle gradient clipping
-           (default: True)
+       step_supports_amp_scaling (boolean, optional): whether to use customized
+           gradient unscaling logic (default: True)
 
    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
 
+    class AtomicCounter(object):
+        def __init__(self):
+            self.value = 0
+            self.order = []
+            import threading
+            self._lock = threading.Lock()
+
+        def add(self, idx):
+            with self._lock:
+                self.value += 1
+                self.order.append(idx)
+
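The nested `AtomicCounter` records the order in which gradient hooks fire during the first step, so later steps can process parameters in that same order. A minimal hedged illustration of the recording (the hook wiring around it is imagined, not the commit's code); the diff then continues with the constructor signature below:

```python
import threading

# Hedged sketch: several backward hooks (possibly racing) report their
# parameter index; the counter keeps a thread-safe tally plus the firing order.
class AtomicCounter(object):
    def __init__(self):
        self.value = 0
        self.order = []
        self._lock = threading.Lock()

    def add(self, idx):
        with self._lock:
            self.value += 1
            self.order.append(idx)

counter = AtomicCounter()
for param_i in (3, 0, 2, 1):      # pretend hooks fired in this order
    counter.add(param_i)
print(counter.value)              # 4
print(counter.order)              # [3, 0, 2, 1] -> used later to reorder bookkeeping lists
```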
     def __init__(self, params,
                  lr=1e-3, bias_correction=True, grad_averaging=True,
                  betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
-                 amp_scale_adjustment=1.0, overlap_reductions=True,
+                 adam_w_mode=True, use_nvlamb=False,
+                 step_supports_amp_scaling=True, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False):
+                 e5m2_allgather=False, verbose=False, clip_after_ar=True):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
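As a quick orientation, a hedged usage sketch of the updated constructor (argument values are illustrative; `model` is assumed to exist and `torch.distributed` must already be initialized):

```python
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

# Hedged sketch: not taken from the commit, just one plausible configuration.
optimizer = DistributedFusedLAMB(
    model.parameters(),
    lr=4e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01,
    max_grad_norm=1.0,             # enables the fused global-norm clipping path
    overlap_reductions=True,       # overlap reduce-scatter/all-reduce with backward
    clip_after_ar=True,            # clip using the post-all-reduce global norm
    step_supports_amp_scaling=True,
    verbose=False,
)
```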
@@ -81,46 +95,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         super(DistributedFusedLAMB, self).__init__(params, defaults)
 
-        self._init_args = {
-            'lr': lr, 'bias_correction': bias_correction,
-            'grad_averaging': grad_averaging, 'betas': betas, 'eps': eps,
-            'weight_decay': weight_decay, 'max_grad_norm': max_grad_norm,
-            'adam_w_mode': adam_w_mode, 'use_nvlamb': use_nvlamb,
-            'clip_grad_norm': clip_grad_norm,
-            'amp_scale_adjustment': amp_scale_adjustment,
-            'overlap_reductions': overlap_reductions,
-            'dwu_group_size': dwu_group_size,
-            'dwu_num_blocks': dwu_num_blocks, 'dwu_num_chunks': dwu_num_chunks,
-            'dwu_num_rs_pg': dwu_num_rs_pg, 'dwu_num_ar_pg': dwu_num_ar_pg,
-            'dwu_num_ag_pg': dwu_num_ag_pg, 'e5m2_allgather': e5m2_allgather
-        }
-        self._init_done = False
-
-        import inspect
-        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
-
-    def __first_step_init__(self,
-                 lr=1e-3, bias_correction=True, grad_averaging=True,
-                 betas=(0.9, 0.999), eps=1e-8,
-                 weight_decay=0., max_grad_norm=0.,
-                 adam_w_mode=True, use_nvlamb=False, clip_grad_norm=True,
-                 amp_scale_adjustment=1.0, overlap_reductions=True,
-                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
-                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False):
         global fused_adam_cuda, distributed_lamb_cuda
         fused_adam_cuda = importlib.import_module("fused_adam_cuda")
         distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
 
-        self._amp_scale_adjustment = amp_scale_adjustment
-
         self._overflow_buf = torch.cuda.IntTensor([0])
         self._has_overflow = False
 
         self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
@@ -128,9 +106,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         import amp_C
         self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
 
         self._grad_averaging = grad_averaging
         self._adam_w_mode = 1 if adam_w_mode else 0
         self._use_nvlamb = use_nvlamb
-        self._clip_grad_norm = clip_grad_norm
+        self._step_supports_amp_scaling = step_supports_amp_scaling
+        self._is_accumulation_step = False
         self._last_step = False
         self._overlap_reductions = overlap_reductions
@@ -138,44 +117,138 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._num_blocks = dwu_num_blocks
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
+        self._verbose = verbose
+        self._clip_after_ar = clip_after_ar
         self._L2_grad_norm = None
+        self._current_process_group = c10d._get_default_group()
+        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
         self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
         self._world_size = torch.distributed.get_world_size()
         self._num_groups = self._world_size // self._group_size
         self._rank_in_group = torch.distributed.get_rank() % self._group_size
+        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')
+        self._resume_from_checkpoint = False
+        self._step = torch.cuda.IntTensor([0])
+
+        # Master weight, moment, gradient buffers
+        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
 
+        import inspect
+        #assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+
+        self._num_rs_pg = dwu_num_rs_pg
+        self._num_ar_pg = dwu_num_ar_pg
+        self._num_ag_pg = dwu_num_ag_pg
+        if self._num_groups > 1:
+            self._ar_pg = []
+            for dev_i in range(self._group_size):
+                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
+                for i in range(self._num_ar_pg):
+                    if self._verbose:
+                        print(f"creating new group {i}: {ranks}")
+                    grp = torch.distributed.new_group(ranks=ranks)
+                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
+                        if self._verbose:
+                            print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
+                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
+                    if self._verbose:
+                        print(f"created new group {i}")
+                    if torch.distributed.get_rank() in ranks:
+                        self._ar_pg.append(grp)
+            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+            #for ar_pg in self._ar_pg:
+            #    torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
+        rs_ranks = []
+        for group_i in range(self._num_groups):
+            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
+        self._rs_pg = []
+        for group_i in range(self._num_groups):
+            ranks = rs_ranks[group_i]
+            for i in range(self._num_rs_pg):
+                grp = torch.distributed.new_group(ranks=ranks)
+                if torch.distributed.get_rank() in ranks:
+                    self._rs_pg.append(grp)
+            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+            if torch.distributed.get_rank() in ranks:
+                self._l2_grad_norm_pg = l2_grad_norm_pg
+                #torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
+        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
+        #for rs_pg in self._rs_pg:
+        #    torch.distributed.all_reduce(self._overflow_buf, group=rs_pg)
+        if self._num_ag_pg == 0:
+            self._ag_pg = self._rs_pg
+            self._ag_st = self._rs_st
+            self._num_ag_pg = self._num_rs_pg
+        else:
+            self._ag_pg = []
+            for group_i in range(self._num_groups):
+                ranks = rs_ranks[group_i]
+                for i in range(self._num_ag_pg):
+                    grp = torch.distributed.new_group(ranks=ranks)
+                    if torch.distributed.get_rank() in ranks:
+                        self._ag_pg.append(grp)
+            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+            #for ag_pg in self._ag_pg:
+            #    torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
+        self._l2_grad_norm_st = torch.cuda.Stream()
+        self._completion_st = torch.cuda.Stream()
+        self._step.record_stream(self._completion_st)
+
+        self._reductions_works = [None]*self._num_blocks
+        self._allgather_works = [None]*self._num_blocks
+
+        self._one = torch.cuda.IntTensor([1])
+
+        self._first_step = True
+        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
+        self._param_order = self.AtomicCounter()
+
+    def _lazy_init_stage1(self):
+        if self._lazy_init_stage1_done: return
+
         p_offset = 0
         p_i = 0
         self._model_params = []
         self._grads_info = []
         self._grad_accs = []
         self._group_properties = []
         for group in self.param_groups:
             prev = None
             beta1, beta2 = group['betas']
+            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
+            bias_correction = 1 if group['bias_correction'] else 0
+            eps = group['eps']
+            weight_decay = group['weight_decay']
             for p in group['params']:
-                torch.distributed.broadcast(p, 0)
+                torch.distributed.broadcast(p, 0)
                 if not p.requires_grad:
                     continue
                 self._model_params.append(p)
                 self._group_properties.append((
-                    group['weight_decay'],
-                    1 if group['bias_correction'] else 0,
+                    weight_decay,
+                    bias_correction,
                     beta1,
                     beta2,
-                    1.0 - beta1 if grad_averaging else 1.0,
-                    group['eps']
+                    beta3,
+                    eps
                     ))
                 p_grads_size = p.numel()
-                def wrapper(param, param_i, param_grads_size, param_offset):
+                def wrapper(param, param_i):
                     param_tmp = param.expand_as(param)
                     grad_acc = param_tmp.grad_fn.next_functions[0][0]
                     def allreduce_hook(*unused):
-                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
+                        if self._first_step:
+                            # first time
+                            self._param_order.add(param_i)
+                        else:
+                            idx = self._param_order.order.index(param_i)
+                            self._do_overlapped_reduction(idx, param)
                     grad_acc.register_hook(allreduce_hook)
                     self._grad_accs.append(grad_acc)
                 self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
-                wrapper(p, p_i, p_grads_size, p_offset)
+                wrapper(p, p_i)
                 p_offset += p_grads_size
                 # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                 # RNN is one example of consecutive parameters:
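The process-group construction that now lives in `__init__` pairs ranks two ways: reduce-scatter / all-gather groups span the consecutive ranks of one group, while all-reduce groups connect the same device index across groups. A hedged sketch of the rank layout for a 2-group x 4-GPU job (plain arithmetic, no collectives created):

```python
# Hedged sketch: rank layout for world_size = 8, group_size = 4 (num_groups = 2).
group_size, num_groups = 4, 2

# reduce-scatter / all-gather groups: consecutive ranks within a group
rs_ranks = [[group_i * group_size + j for j in range(group_size)]
            for group_i in range(num_groups)]
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]

# all-reduce groups: the same device index across groups
ar_ranks = [[dev_i + j * group_size for j in range(num_groups)]
            for dev_i in range(group_size)]
# -> [[0, 4], [1, 5], [2, 6], [3, 7]]
print(rs_ranks, ar_ranks)
```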
@@ -184,7 +257,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     p_offset = ((p_offset + 63) // 64) * 64
                 prev = p
                 p_i += 1
-        self._grads_generated = [False]*len(self._grads_info)
+        self._grads_generated = [False]*len(self._model_params)
         self._grads_fp16, self._grads_fp32 = [], []
         if self._overlap_reductions:
             self._current_block = self._num_blocks
@@ -196,31 +269,21 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._block_size = self._total_param_size // self._num_blocks
         self._chunk_size = self._block_size // self._num_chunks
         self._shard_size = self._chunk_size // self._group_size
-        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
-
-        self._low_param_i = [0]*self._num_blocks
-        for block_id in range(self._num_blocks-1,-1,-1):
-            p_i = len(self._grads_info)-1
-            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
-                p_i -= 1
-            self._low_param_i[block_id] = p_i
-        print(self._low_param_i)
+        #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
 
         self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
         self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
-        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
-        self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+        # initialize master weights, moments buffers if not loaded from checkpoint
+        if self._fp32_p is None:
+            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
+            self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
         # FIXME: Rethink fp16 label since it's either uint8 or fp16
         self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
 
         self._individual_flat_grads = []
         for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
             self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))
 
         def _flat_split(p):
             def __blockify(p):
                 return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
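The sizes above follow a simple hierarchy. As a hedged illustration with made-up numbers, this is how the flat gradient buffer is carved into blocks, chunks, and per-rank shards, and why each rank owns one "mega shard":

```python
# Hedged sketch: illustrative sizes only.
total_param_size = 2**20          # padded total number of elements
num_blocks, num_chunks, group_size = 4, 4, 8

block_size = total_param_size // num_blocks               # 262144
chunk_size = block_size // num_chunks                     # 65536
shard_size = chunk_size // group_size                     # 8192
mega_shard_size = num_blocks * num_chunks * shard_size    # 131072 elements per rank
assert mega_shard_size == total_param_size // group_size  # each rank owns 1/group_size of the buffer
```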
@@ -262,6 +325,45 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
         self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
 
+        self._lazy_init_stage1_done = True
+
+    def _lazy_init_stage2(self):
+        if self._lazy_init_stage2_done: return
+
+        self._param_order.order.reverse()
+
+        # re-order model_params, grad_accs, group_properties lists
+        self._model_params = [self._model_params[i] for i in self._param_order.order]
+        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
+        self._group_properties = [self._group_properties[i] for i in self._param_order.order]
+
+        # re-collect grads info (size, offset) after ordering
+        prev = None
+        p_offset = 0
+        self._grads_info = []
+        self._individual_flat_grads = []
+        for i, p in enumerate(self._model_params):
+            p_grads_size = p.numel()
+            self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
+            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
+            # for the first iteration
+            self._do_overlapped_reduction(i, p)
+            p_offset += p_grads_size
+            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
+            # RNN is one example of consecutive parameters:
+            # (weight_ih, weight_hh, bias_ih, bias_hh)
+            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
+                p_offset = ((p_offset + 63) // 64) * 64
+            prev = p
+
+        self._low_param_i = [0]*self._num_blocks
+        for block_id in range(self._num_blocks-1,-1,-1):
+            p_i = len(self._grads_info)-1
+            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
+                p_i -= 1
+            self._low_param_i[block_id] = p_i
+        #print("self._low_param_i", self._low_param_i)
+
         # This paragraph does two things:
         # 1) Copy model parameters into master buffer
         # 2) Create tensor lists for unpacking new parameter tensor after all-gather
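`_lazy_init_stage2` reorders the bookkeeping lists to match the order in which gradients actually arrived during the first backward pass (recorded by `AtomicCounter`, then reversed because hooks fire roughly in reverse execution order). A small hedged illustration of the reordering idiom, with made-up lists:

```python
# Hedged sketch: reorder parallel lists by a recorded arrival order.
model_params = ['p0', 'p1', 'p2', 'p3']
group_properties = ['g0', 'g1', 'g2', 'g3']

order = [3, 0, 2, 1]        # hook arrival order recorded during the first step
order.reverse()             # -> [1, 2, 0, 3]

model_params = [model_params[i] for i in order]
group_properties = [group_properties[i] for i in order]
print(model_params)         # ['p1', 'p2', 'p0', 'p3'] -- offsets are then recomputed in this order
```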
@@ -274,7 +376,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._contrib_model_param_for_norm_fp16 = []
         self._contrib_model_param_for_norm_fp32 = []
         self._contrib_model_param_for_norm_is_fp16 = []
-        self._model_param_is_contrib = [False]*self._model_params_num
+        self._model_param_is_contrib = []
         self._contrib_group_properties = []
         for shard_id in range(self._group_size):
             for block_id in range(self._num_blocks):
@@ -297,7 +399,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     else:
                         self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                     if shard_id == self._rank_in_group:
-                        self._model_param_is_contrib[param_i] = True
+                        self._model_param_is_contrib.append(param_i)
                         # copy model parameters into master buffer
                         master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
@@ -306,7 +408,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                         opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                         #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
-                        master_param_fragment.copy_(model_param_fragment)
+                        if not self._resume_from_checkpoint:
+                            master_param_fragment.copy_(model_param_fragment)
                         self._contrib_group_properties.append(group_props)
                         self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
                         self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
@@ -322,7 +425,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
         self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
         self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
-        self._model_param_is_contrib = torch.tensor(self._model_param_is_contrib, dtype=torch.bool, device='cuda')
+        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')
 
         p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
         self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
@@ -340,62 +443,10 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
         self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None
 
-        self._num_rs_pg = dwu_num_rs_pg
-        self._num_ar_pg = dwu_num_ar_pg
-        self._num_ag_pg = dwu_num_ag_pg
-        if self._num_groups > 1:
-            self._ar_pg = []
-            for dev_i in range(self._group_size):
-                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
-                for i in range(self._num_ar_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ar_pg.append(grp)
-            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
-            for ar_pg in self._ar_pg:
-                torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
-        rs_ranks = []
-        for group_i in range(self._num_groups):
-            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
-        self._rs_pg = []
-        for group_i in range(self._num_groups):
-            ranks = rs_ranks[group_i]
-            for i in range(self._num_rs_pg):
-                grp = torch.distributed.new_group(ranks=ranks)
-                if torch.distributed.get_rank() in ranks:
-                    self._rs_pg.append(grp)
-            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
-            if torch.distributed.get_rank() in ranks:
-                self._l2_grad_norm_pg = l2_grad_norm_pg
-                torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
-        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
-        for rs_pg in self._rs_pg:
-            torch.distributed.all_reduce(self._overflow_buf, group=rs_pg)
-        if self._num_ag_pg == 0:
-            self._ag_pg = self._rs_pg
-            self._ag_st = self._rs_st
-            self._num_ag_pg = self._num_rs_pg
-        else:
-            self._ag_pg = []
-            for group_i in range(self._num_groups):
-                ranks = rs_ranks[group_i]
-                for i in range(self._num_ag_pg):
-                    grp = torch.distributed.new_group(ranks=ranks)
-                    if torch.distributed.get_rank() in ranks:
-                        self._ag_pg.append(grp)
-            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
-            for ag_pg in self._ag_pg:
-                torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
-        self._l2_grad_norm_st = torch.cuda.Stream()
-        self._completion_st = torch.cuda.Stream()
-
-        self._reductions_works = [None]*self._num_blocks
-        self._allgather_works = [None]*self._num_blocks
-
-    def _init_everything(self):
-        if not self._init_done:
-            self.__first_step_init__(**self._init_args)
-            self._init_done = True
+        self._lazy_init_stage2_done = True
+
+        self.complete_reductions()
+        self._first_step = False
 
     def set_is_accumulation_step(self, is_accumulation_step):
         self._is_accumulation_step = is_accumulation_step
@@ -420,40 +471,87 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return flush_block
 
     def _pipeline_block_reductions(self, block_id):
-        self._flatten_grad_mt(1.0/self._world_size)
-
-        # Reduction within each node
-        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
-        # The output format is the same as the fp32 master parameters
-        works = [None]*self._num_chunks
-        for chunk_id in range(self._num_chunks):
-            glob_chunk_id = block_id * self._num_chunks + chunk_id
-            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
-            rs_stream.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(rs_stream):
-                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=True)
-
-        # Reduction across nodes for each rank
-        if self._num_groups > 1:
-            for chunk_id in range(self._num_chunks):
-                glob_chunk_id = block_id * self._num_chunks + chunk_id
-                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
-                with torch.cuda.stream(ar_stream):
-                    works[chunk_id].wait()
-                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
-        self._reductions_works[block_id] = works
-
-        # Compute L2 grad norm
-        if block_id == 0:
-            with torch.cuda.stream(self._l2_grad_norm_st):
-                for block_id in range(self._num_blocks):
-                    for chunk_id in range(self._num_chunks):
-                        self._reductions_works[block_id][chunk_id].wait()
-                # Since the packed format is contiguous after reductions, only one norm is needed
-                l2_grad_norm_sq = torch.empty([1], device='cuda')
-                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
-                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
-                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
+        if self._clip_after_ar:
+            self._flatten_grad_mt(1.0/self._world_size)
+            # Reduction within each node
+            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
+            # The output format is the same as the fp32 master parameters
+            works = [None]*self._num_chunks
+            for chunk_id in range(self._num_chunks):
+                glob_chunk_id = block_id * self._num_chunks + chunk_id
+                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
+                rs_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(rs_stream):
+                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=False)
+            # Reduction across nodes for each rank
+            if self._num_groups > 1:
+                for chunk_id in range(self._num_chunks):
+                    glob_chunk_id = block_id * self._num_chunks + chunk_id
+                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+                    with torch.cuda.stream(ar_stream):
+                        works[chunk_id].wait()
+                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
+            self._reductions_works[block_id] = works
+            # Compute L2 grad norm
+            if block_id == 0:
+                with torch.cuda.stream(self._l2_grad_norm_st):
+                    for block_id in range(self._num_blocks):
+                        for chunk_id in range(self._num_chunks):
+                            self._reductions_works[block_id][chunk_id].wait()
+                    # Since the packed format is contiguous after reductions, only one norm is needed
+                    l2_grad_norm_sq = torch.empty([1], device='cuda')
+                    l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
+                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
+                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()
+        else:
+            # Copy model grads to flat grads buffer
+            self._flatten_grad_mt(1.0)
+
+            # Compute L2 grad norm
+            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(self._l2_grad_norm_st):
+                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
+            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
+
+            # Apply clipping & pre-reduction scaling on grads
+            loss_scale = self.global_scale
+            max_grad_norm = loss_scale*self.defaults['max_grad_norm']
+            coeff = max_grad_norm / (1e-6+self.L2_grad_norm)
+            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
+            tmp = torch.cat(((self._one), (coeff)))
+            index = (coeff+1>coeff).int()
+            scale = tmp.index_select(0, index).half()/self._world_size
+            self._flat_grads.mul_(scale)
+
+            # Reduction within each node
+            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
+            # The output format is the same as the fp32 master parameters
+            works = [None]*self._num_chunks
+            for chunk_id in range(self._num_chunks):
+                glob_chunk_id = block_id * self._num_chunks + chunk_id
+                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
+                rs_stream.wait_stream(torch.cuda.current_stream())
+                rs_stream.wait_stream(self._l2_grad_norm_st)
+                with torch.cuda.stream(rs_stream):
+                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=False)
+
+            # Reduction across nodes for each rank
+            if self._num_groups > 1:
+                for chunk_id in range(self._num_chunks):
+                    glob_chunk_id = block_id * self._num_chunks + chunk_id
+                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
+                    with torch.cuda.stream(ar_stream):
+                        works[chunk_id].wait()
+                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
+            self._reductions_works[block_id] = works
+
+            if block_id == 0:
+                for block_id in range(self._num_blocks):
+                    for chunk_id in range(self._num_chunks):
+                        self._reductions_works[block_id][chunk_id].wait()
+                # Since the packed format is contiguous after reductions, only one norm is needed
+                l2_grad_norm_sq = torch.empty([1], device='cuda')
+                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
+                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
+                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()
 
     def __compute_contrib_param_norm(self):
         if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
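The `clip_after_ar=False` branch clips before the reductions using a branch-free, device-side coefficient so no `.item()` host synchronization is needed. A hedged sketch of that trick in isolation (illustrative values; the optimizer itself keeps an int `self._one` buffer for this purpose):

```python
import torch

# Hedged sketch: branch-free "clip then scale" on the GPU, assuming CUDA is available.
one = torch.ones(1, device='cuda')
loss_scale, max_grad_norm, world_size = 1024.0, 1.0, 8
l2_grad_norm = torch.tensor([5.0 * loss_scale], device='cuda')   # norm of still-scaled grads

coeff = (loss_scale * max_grad_norm) / (1e-6 + l2_grad_norm)     # >1 means no clipping needed
coeff = (coeff > 1) * one + (coeff <= 1) * coeff                 # min(coeff, 1) without a branch
tmp = torch.cat((one, coeff))
index = (coeff + 1 > coeff).int()                                # 1 if coeff is finite, 0 for inf/NaN
scale = tmp.index_select(0, index).half() / world_size           # falls back to 1/world_size on overflow
# flat_grads.mul_(scale) would follow, then the reduce-scatter / all-reduce
```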
@@ -471,24 +569,32 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 
     def __compute_contrib_update_norm(self):
         l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
         local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
-        l2_norm.masked_scatter_(self._model_param_is_contrib, local_contrib_l2_norm)
+        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
         torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
         l2_norm = torch.sqrt(l2_norm)
-        return l2_norm.masked_select(self._model_param_is_contrib)
+        return l2_norm
     def _pipeline_step(self):
-        # If self._clip_grad_norm is False, we assume gradient clipping already
-        # happened outside the optimizer and self._global_scale has already
-        # been set to the combined scale, i.e. it's no longer the current loss
-        # scale used by the loss scaler.
-        # For model parallelism cases in which we need to get global gradient
-        # norm via all-reduce outside the optimizer to do the clipping.
-        combined_scale = self.global_scale
-        max_grad_norm = self.defaults['max_grad_norm']
-        self._completion_st.wait_stream(self._l2_grad_norm_st)
+        global_scale = self.global_scale
+        # if clip before ar, set max_grad_norm to 0
+        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
         global_grad_norm = self.L2_grad_norm
-        if self._clip_grad_norm and max_grad_norm > 0 and math.isfinite(global_grad_norm):
-            combined_scale = max_grad_norm / (global_grad_norm / self.global_scale + 1e-6)
-            combined_scale = self.global_scale / min(1, combined_scale)
+
+        # check global_grad_norm and fill overflow_buf
+        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
+        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
+        torch.distributed.all_reduce(is_finite,
+                                     op=torch.distributed.ReduceOp.MIN,
+                                     group=self._current_process_group)
+        torch.distributed.all_reduce(self._overflow_buf,
+                                     op=torch.distributed.ReduceOp.MAX,
+                                     group=self._current_process_group)
+
+        # increment step counter if no overflow
+        self._step += is_finite
+
+        self._completion_st.wait_stream(torch.cuda.current_stream())
+        self._completion_st.wait_stream(self._l2_grad_norm_st)
 
         # Call step kernel once per step
         # Call all-gather once per step
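The overflow check above never leaves the device: for a finite norm, `norm + 1 > norm` is true, while for inf or NaN it is false, so the comparison doubles as a "no overflow" flag and as the step increment. A hedged sketch of the idea in isolation (the real code then all-reduces both flags with MIN/MAX across the process group):

```python
import torch

# Hedged sketch: device-side overflow detection, assuming CUDA is available.
global_grad_norm = torch.tensor([float('inf')], device='cuda')
one = torch.cuda.IntTensor([1])

is_finite = (global_grad_norm + 1 > global_grad_norm).int()   # 0 here because of inf
overflow_buf = one * (is_finite ^ one)                        # 1 -> overflow detected
step = torch.cuda.IntTensor([0])
step += is_finite                                             # step only advances on clean iterations
```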
@@ -504,21 +610,25 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 self._contrib_beta2,
                 self._contrib_beta3,
                 self._contrib_bias_correction,
-                self.param_groups[0]['step'],
+                self._step,
                 self._contrib_epsilon,
                 self._adam_w_mode,
                 self._contrib_weight_decay,
-                combined_scale)
+                global_scale,
+                global_grad_norm,
+                max_grad_norm)
         upd_norm = self.__compute_contrib_update_norm()
         multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                 self._overflow_buf,
                 self._contrib_update_weights_tensor_list, # u, p, p_copy
                 param_norm,
                 upd_norm,
-                self.param_groups[0]['lr'],
+                self._offsets,
+                self._lr,
                 self._contrib_weight_decay,
+                global_grad_norm,
                 self._use_nvlamb)
-        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
+        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=False)
 
     def _flatten_grad_mt(self, scale):
         if len(self._grads_fp16) > 0:
@@ -538,8 +648,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     scale)
             self._grads_fp32 = []
 
-    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
-        self._init_everything()
+    def _do_overlapped_reduction(self, param_i, param):
+        if not self._is_accumulation_step:
             # handle overlapped reductions
             if param.dtype == torch.float16:
@@ -547,12 +656,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             else:
                 self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
             self._grads_generated[param_i]=True
-            if self._overlap_reductions and not self._last_step:
-                flush_block = self._get_flush_block()
-                while flush_block:
-                    block_id = flush_block[0] // self._block_size
-                    self._pipeline_block_reductions(block_id)
+            if not self._first_step and not self._last_step:
+                if self._overlap_reductions:
+                    flush_block = self._get_flush_block()
+                    while flush_block:
+                        block_id = flush_block[0] // self._block_size
+                        self._pipeline_block_reductions(block_id)
                         flush_block = self._get_flush_block()
 
     def set_global_scale(self, global_scale):
         """Set global scale.
@@ -565,14 +675,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
 
     @property
     def L2_grad_norm(self):
-        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
-        return self._L2_grad_norm
+        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
+        return self._L2_grad_norm
 
     def complete_reductions(self):
         """Complete reductions if full pipeline is not selected or overlap is not allowed.
         """
-        self._init_everything()
         if self._last_step:
             # zero out gradients that have not been completed yet
             for param_i, grad_generated in enumerate(self._grads_generated):
@@ -583,7 +691,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._flat_grads[param_offset:param_offset+param_size].zero_()
                     self._grads_generated[param_i]=True
 
-        if self._last_step or not self._overlap_reductions:
+        if self._first_step or self._last_step or not self._overlap_reductions:
             # nothing done so far, run full pipeline after reductions
             for block_id in range(self._num_blocks-1,-1,-1):
                 self._pipeline_block_reductions(block_id)
@@ -593,24 +701,23 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             self._current_block = self._num_blocks
             self._grads_generated = [False]*len(self._grads_info)
 
-    def step(self, closure=None):
+    def step(self, closure=None, grad_scaler=None):
         loss = None
         if closure is not None:
             loss = closure()
 
-        # assume same step across group now to simplify things
-        # per parameter step can be easily support by making it tensor, or pass list into kernel
-        for param_group in self.param_groups:
-            if 'step' in param_group:
-                param_group['step'] += 1
-            else:
-                param_group['step'] = 1
-
         self._pipeline_step()
 
+        if grad_scaler is not None:
+            found_inf = self._overflow_buf.float()
+            optimizer_state = grad_scaler._per_optimizer_states[id(self)]
+            current_device = torch.device('cuda', torch.cuda.current_device())
+            optimizer_state["found_inf_per_device"][current_device] = found_inf
+
         self._completion_st.wait_stream(torch.cuda.current_stream())
 
         with torch.cuda.stream(self._completion_st):
             # Copy self._new_params to model params
             self._overflow_buf.zero_()
             with torch.no_grad():
                 if self._packed_flat_to_model_params_fp16 is not None:
                     multi_tensor_applier(
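With `step_supports_amp_scaling` set and the new `grad_scaler` argument, the optimizer can report overflow status straight into a `torch.cuda.amp.GradScaler`. A heavily hedged sketch of one possible training-loop wiring (not from the commit; `model`, `criterion`, and `loader` are assumed to exist, and the exact GradScaler plumbing varies across PyTorch versions):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = criterion(model(inputs), targets)
    scaler.scale(loss).backward()
    optimizer.set_global_scale(scaler.get_scale())  # loss scale becomes the optimizer's global_scale
    optimizer.complete_reductions()
    optimizer.step(grad_scaler=scaler)              # writes found_inf into the scaler's per-optimizer state
    scaler.update()                                 # scaler reacts to the reported overflow
```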
@@ -630,4 +737,42 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return loss
 
+    def state_dict(self):
+        """
+        Returns a dict containing the current state of this :class:`DistributedFusedAdam` instance.
+        Example::
+            checkpoint = {}
+            checkpoint['model'] = model.state_dict()
+            checkpoint['optimizer'] = optimizer.state_dict()
+            torch.save(checkpoint, "saved.pth")
+        """
+        # save step, master weights and first/second moments
+        state_dict = {}
+        state_dict['step'] = self._step
+        state_dict['fp32_p'] = self._fp32_p
+        state_dict['fp32_m'] = self._fp32_m
+        state_dict['fp32_v'] = self._fp32_v
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        """
+        Loads a state_dict created by an earlier call to state_dict().
+        If an DistributedFusedAdam instance was constructed from some ``init_optimizer``,
+        whose parameters in turn came from ``model``, it is expected that the user
+        will call ``model.load_state_dict()`` before
+        ``optimizer.load_state_dict()`` is called.
+        Example::
+            model = torch.nn.Linear(D_in, D_out).cuda().half()
+            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
+            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
+            ...
+            checkpoint = torch.load("saved.pth")
+            model.load_state_dict(checkpoint['model'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+        """
+        # restore step, master weights and first/second moments
+        self._step = state_dict['step']
+        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
+        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
+        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
+        self._resume_from_checkpoint = True
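The saved state contains only the step counter and the fp32 master weights and moments owned by this rank, so each rank would typically write and load its own checkpoint file and be restored onto the same world size and sharding. A hedged usage sketch (file names and `model` are illustrative):

```python
import torch

# Hedged sketch: per-rank save and resume with the new optimizer state support.
rank = torch.distributed.get_rank()
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),   # {'step', 'fp32_p', 'fp32_m', 'fp32_v'}
}
torch.save(checkpoint, f"ckpt_rank{rank}.pth")

# later, after rebuilding model and optimizer with the same world size / sharding:
checkpoint = torch.load(f"ckpt_rank{rank}.pth", map_location="cpu")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])   # sets _resume_from_checkpoint = True
```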